| #include <array> |
| #include <cstdint> |
| #include <functional> |
| #include <memory> |
| #include <numeric> |
| #include <optional> |
| #include <tuple> |
| #include <utility> |
| #include <vector> |
| |
| #include "learning/brain/tpu/runtime/topology/tpu_topology.h" |
| #include "learning/brain/tpu/runtime/topology/tpu_topology_serdes.h" |
| #include "learning/brain/tpu/runtime/tpu_version.h" |
| #include "platforms/xla/mosaic/dialect/llo/llo_dialect.h" |
| #include "platforms/xla/mosaic/dialect/tpu/tpu_passes.h" |
| #include "platforms/xla/service/jellyfish/gain_latch_mode.h" |
| #include "platforms/xla/service/jellyfish/matmul_data_format.h" |
| #include "platforms/xla/service/jellyfish/target.h" |
| #include "platforms/xla/service/jellyfish/transfer_size_util.h" |
| #include "platforms/xla/service/jellyfish/vpack_format.h" |
| #include "third_party/absl/log/check.h" |
| #include "third_party/absl/log/log.h" |
| #include "third_party/absl/types/span.h" |
| #include "third_party/llvm/llvm-project/llvm/include/llvm/Support/MathExtras.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Arith/IR/Arith.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Func/IR/FuncOps.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Math/IR/Math.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/Transforms/Transforms.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/AffineMap.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypeInterfaces.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypes.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/PatternMatch.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/ValueRange.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Transforms/DialectConversion.h" |
| #include "third_party/py/jax/jaxlib/mosaic/dialect/tpu/tpu_dialect.h" |
| #include "third_party/tensorflow/compiler/xla/mlir/utils/type_util.h" |
| #include "third_party/tensorflow/compiler/xla/shape_util.h" |
| |
| namespace mlir::tpu { |
| |
| #define GEN_PASS_DECL_LOWERTOLLOPASS |
| #define GEN_PASS_DEF_LOWERTOLLOPASS |
| #include "platforms/xla/mosaic/dialect/tpu/tpu_passes.h.inc" |
| |
| namespace { |
| |
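| // Returns true iff `ty` is annotated with the given TPU memory space. |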
| bool hasMemorySpace(MemRefType ty, tpu::MemorySpace space) { |
| auto memory_space = |
| dyn_cast_or_null<tpu::MemorySpaceAttr>(ty.getMemorySpace()); |
| return memory_space && memory_space.getValue() == space; |
| } |
| |
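| // Appends the constant values of `indices` to `storage`. Returns false (and |
| // may leave `storage` partially filled) if any index is not an llo::ConstantOp. |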
| bool getConstantIndices(SmallVectorImpl<int64_t> *storage, ValueRange indices) { |
| CHECK(storage->empty()); |
| for (auto ix : indices) { |
| auto ix_cst = ix.getDefiningOp<llo::ConstantOp>(); |
| if (!ix_cst) { |
| return false; |
| } |
| storage->push_back(cast<IntegerAttr>(ix_cst.getValue()).getInt()); |
| } |
| return true; |
| } |
| |
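| // Splits a VMEM access into a base address and an optional displacement. |
| // Constant offsets are returned as the displacement; dynamic offsets are |
| // folded into the address via llo::ScalarAddressVmemOp. A null offset maps |
| // to (addr, nullptr). |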
| std::pair<Value, Value> vmemAddrOffsetToAddrDisplacement( |
| PatternRewriter &rewriter, Value addr, Value offset) { |
| if (!offset) { |
| return std::make_pair(addr, nullptr); |
| } |
| if (offset.getDefiningOp<llo::ConstantOp>()) { |
| return std::make_pair(addr, offset); |
| } |
| return std::make_pair( |
| rewriter.create<llo::ScalarAddressVmemOp>(offset.getLoc(), addr, offset) |
| .getResult(), |
| nullptr); |
| } |
| |
| template <typename Op> |
| using PatternFunc = std::function<LogicalResult(Op, typename Op::Adaptor, |
| ConversionPatternRewriter &)>; |
| |
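| // Registers `pattern` by wrapping it in an OpConversionPattern, so plain |
| // lambdas can be used as dialect conversion patterns. |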
| template <typename Op> |
| void addPattern(RewritePatternSet &patterns, PatternFunc<Op> pattern) { |
| class FuncConversion final : public OpConversionPattern<Op> { |
| public: |
| FuncConversion(PatternFunc<Op> pattern, MLIRContext *context) |
| : OpConversionPattern<Op>(context), pattern_(std::move(pattern)) {} |
| |
| LogicalResult matchAndRewrite( |
| Op op, typename Op::Adaptor adaptor, |
| ConversionPatternRewriter &rewriter) const override { |
| return pattern_(op, adaptor, rewriter); |
| } |
| |
| private: |
| PatternFunc<Op> pattern_; |
| }; |
| patterns.insert(std::make_unique<FuncConversion>(std::move(pattern), |
| patterns.getContext())); |
| } |
| |
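| // Maps an element bitwidth to the corresponding compressed vpack format, or |
| // kInvalid if the bitwidth is unsupported. |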
| llo::VpackFormat GetCompressedByteFormat(unsigned bitwidth) { |
| switch (bitwidth) { |
| case 16: |
| return llo::VpackFormat::kCompressedB16; |
| case 8: |
| return llo::VpackFormat::kCompressedB8; |
| case 4: |
| return llo::VpackFormat::kCompressedB4; |
| case 2: |
| return llo::VpackFormat::kCompressedB2; |
| case 1: |
| return llo::VpackFormat::kCompressedB1; |
| default: |
| return llo::VpackFormat::kInvalid; |
| } |
| } |
| |
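| // Owns a TPU topology and the jellyfish Target constructed from it for a |
| // given TPU version (see tryCreatingMock below). |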
| class MockTpuInstance { |
| public: |
| explicit MockTpuInstance(::tpu::TpuVersion version) { |
| auto maybe_topology = |
| ::tpu::TpuTopologySerdes::Construct({.version = version}); |
| CHECK_OK(maybe_topology.status()); |
| topology_ = std::move(maybe_topology.value()); |
| auto maybe_target = |
| ::xla::jellyfish::Target::CreateFromTopology(topology_.get()); |
| CHECK_OK(maybe_target.status()); |
| target_ = std::move(maybe_target.value()); |
| } |
| |
| ::xla::jellyfish::Target *target() { return target_.get(); } |
| |
| private: |
| std::unique_ptr<const ::tpu::TpuTopology> topology_; |
| std::unique_ptr<::xla::jellyfish::Target> target_; |
| }; |
| |
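| // Conversion pass that lowers vector, arith, math, memref and cf ops to the |
| // LLO dialect for a concrete jellyfish target. |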
| class LowerToLLOPass : public impl::LowerToLLOPassBase<LowerToLLOPass> { |
| public: |
| explicit LowerToLLOPass(const xla::jellyfish::Target *target, |
| bool unsafe_allow_multicore_remote_dma) |
| : target_(target), |
| unsafe_allow_multicore_remote_dma_(unsafe_allow_multicore_remote_dma) {} |
| LowerToLLOPass() |
| : target_(nullptr), unsafe_allow_multicore_remote_dma_(false) {} |
| |
| protected: |
| void runOnOperation() override; |
| |
| private: |
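| // Creates a mock TPU target when the `mock_target` option selects a |
| // supported generation (2: Jellyfish, 3: Dragonfish, 4: Pufferfish); |
| // returns nullptr otherwise. |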
| std::unique_ptr<MockTpuInstance> tryCreatingMock() { |
| if (mock_target == 2) { |
| return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kJellyfish); |
| } |
| if (mock_target == 3) { |
| return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kDragonfish); |
| } |
| if (mock_target == 4) { |
| return std::make_unique<MockTpuInstance>(::tpu::TpuVersion::kPufferfish); |
| } |
| return nullptr; |
| } |
| |
| int64_t laneCount() { return target_->LaneCount(); } |
| int64_t sublaneCount() { return target_->SublaneCount(); } |
| |
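| // Returns true iff `ty` maps directly onto a single native vreg: |
| // (sublanes, lanes) for 32-bit elements, or (sublanes, lanes, 32 / bitwidth) |
| // for narrower packed elements. |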
| bool isRepresentableVectorType(Type ty) { |
| auto vty = dyn_cast<VectorType>(ty); |
| if (!vty) { |
| return false; |
| } |
| auto ety = vty.getElementType(); |
| if (!ety.isF32() && !ety.isBF16() && !ety.isSignlessInteger()) { |
| return false; |
| } |
| if (vty.getRank() == 2) { |
| return ety.getIntOrFloatBitWidth() == 32 && |
| vty.getShape() == ArrayRef<int64_t>{sublaneCount(), laneCount()}; |
| } |
| return vty.getShape() == |
| ArrayRef<int64_t>{sublaneCount(), laneCount(), |
| 32 / ety.getIntOrFloatBitWidth()}; |
| } |
| |
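| // Native vreg shape for the given element bitwidth. |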
| std::vector<int64_t> nativeVectorShape(unsigned bitwidth) { |
| if (bitwidth == 32) { |
| return {sublaneCount(), laneCount()}; |
| } |
| CHECK_LT(bitwidth, 32); |
| CHECK(llvm::isPowerOf2_32(bitwidth)); |
| return {sublaneCount(), laneCount(), 32 / bitwidth}; |
| } |
| |
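| // Returns true iff `ty` is an i1 vector whose leading dims are |
| // (sublanes, lanes), optionally with a trailing multiplicity supported by |
| // the target's vmreg width. |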
| bool isMaskVectorType(Type ty) { |
| auto vty = dyn_cast<VectorType>(ty); |
| if (!vty || !vty.getElementType().isSignlessInteger(1) || |
| vty.getRank() < 2) { |
| return false; |
| } |
| if (vty.getShape().take_front(2) != |
| ArrayRef<int64_t>{sublaneCount(), laneCount()}) { |
| return false; |
| } |
| if (vty.getRank() == 3) { |
| int32_t max_bits = target_->BitsPerVmregLaneAndSublane(); |
| return vty.getDimSize(2) == 1 || |
| (max_bits >= 2 && vty.getDimSize(2) == 2) || |
| (max_bits >= 4 && vty.getDimSize(2) == 4); |
| } |
| return vty.getRank() == 2; |
| } |
| |
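| // Reinterprets a tiled 1D memref as an equivalent (len / lanes, lanes) 2D |
| // memref and rewrites `indices` accordingly. Fails if the 1D layout is not |
| // a supported lane-aligned tiling. |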
| LogicalResult convertMemRef1DTo2D(MemRefType *ty, |
| SmallVector<int64_t, 2> *indices) { |
| CHECK_EQ(ty->getRank(), 1); |
| auto lane_count = laneCount(); |
| auto layout_1d = dyn_cast<tpu::TiledLayoutAttr>(ty->getLayout()); |
| if (!layout_1d) { |
| return failure(); |
| } |
| auto packing = 32 / ty->getElementTypeBitWidth(); |
| if (packing == 1) { |
| if (layout_1d.getTiles().size() != 1) { |
| return failure(); |
| } |
| } else { |
| if (layout_1d.getTiles().size() != 3) { |
| return failure(); |
| } |
| auto second = layout_1d.getTiles()[1].dimensions(); |
| auto third = layout_1d.getTiles()[2].dimensions(); |
| if (second.size() != 1 || second.front() != lane_count || |
| third.size() != 2 || third[0] != packing || third[1] != 1) { |
| return failure(); |
| } |
| } |
| absl::Span<const int64_t> tiling_1d = |
| layout_1d.getTiles().front().dimensions(); |
| if (tiling_1d.size() != 1 || tiling_1d.front() % lane_count != 0 || |
| ty->getDimSize(0) % lane_count != 0) { |
| return failure(); |
| } |
| SmallVector<xla::Tile, 2> tiles; |
| tiles.push_back(::xla::Tile({tiling_1d.front() / lane_count, lane_count})); |
| if (packing != 1) { |
| tiles.push_back(::xla::Tile({packing, 1})); |
| } |
| auto layout_2d = tpu::TiledLayoutAttr::get( |
| ty->getContext(), tiles, {layout_1d.getTileStrides()[0], 1}); |
| int64_t length = ty->getDimSize(0); |
| std::array<int64_t, 2> shape_2d{length / lane_count, lane_count}; |
| *ty = MemRefType::get(shape_2d, ty->getElementType(), layout_2d, |
| ty->getMemorySpace()); |
| indices->push_back(indices->front() % lane_count); |
| indices->front() /= lane_count; |
| return success(); |
| } |
| |
| // The read memory region should be of shape [num_read_sublanes, 128]. |
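| // Returns the displacement (in sublanes) into the memref for `indices`: a |
| // null Value when it folds to zero, an i32 value otherwise, and std::nullopt |
| // if the access cannot be expressed as a sublane displacement. |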
| std::optional<Value> indicesToVmemOffset(ValueRange indices, |
| int64_t num_read_sublanes, |
| MemRefType ty, |
| PatternRewriter &rewriter, |
| bool consistent_directions = true) { |
| if (!ty.hasStaticShape() || indices.size() != ty.getRank() || |
| ty.getRank() < 1) { |
| return std::nullopt; |
| } |
| |
| CHECK_NE(ty.getRank(), 0); |
| int64_t tiled_dims = ty.getRank() == 1 ? 1 : 2; |
| SmallVector<int64_t, 2> tile_idx; |
| if (!getConstantIndices(&tile_idx, indices.take_back(tiled_dims))) { |
| return std::nullopt; |
| } |
| if (ty.getRank() == 1) { |
| if (convertMemRef1DTo2D(&ty, &tile_idx).failed()) { |
| return std::nullopt; |
| } |
| } |
| // Now we can assume we're working with 2D data. |
| |
| int64_t packing = 32 / ty.getElementTypeBitWidth(); |
| auto layout = dyn_cast<tpu::TiledLayoutAttr>(ty.getLayout()); |
| if (!layout) { |
| return std::nullopt; |
| } |
| auto tiling = layout.getTiles(); |
| if (tiling.empty()) { |
| return std::nullopt; |
| } |
| auto leading_tile = tiling.front().dimensions(); |
| if (leading_tile.size() != 2) { |
| return std::nullopt; |
| } |
| int64_t tile_major = leading_tile[0]; |
| int64_t tile_minor = leading_tile[1]; |
| if (tiling.size() == 1) { |
| if (packing != 1) { |
| return std::nullopt; |
| } |
| } else if (tiling.size() == 2) { |
| auto trailing_tile = tiling.back().dimensions(); |
| if (trailing_tile.size() != 2 || trailing_tile[0] != packing || |
| trailing_tile[1] != 1) { |
| return std::nullopt; |
| } |
| } else { |
| return std::nullopt; |
| } |
| |
| auto batch_strides = layout.getTileStrides().drop_back(2); |
| auto tile_strides = layout.getTileStrides().take_back(2); |
| auto ty_tile_shape = ty.getShape().take_back(2); |
| if (ty_tile_shape[0] % tile_major != 0 || |
| ty_tile_shape[1] % tile_minor != 0) { |
| return std::nullopt; // Memref doesn't have a padded size? |
| } |
| |
| // Compute the tile coordinates. |
| int64_t major_among_tiles = tile_idx[0] / tile_major; |
| int64_t minor_among_tiles = tile_idx[1] / tile_minor; |
| int64_t major_in_tile = tile_idx[0] % tile_major; |
| int64_t minor_in_tile = tile_idx[1] % tile_minor; |
| // We assume tiles are Nx128, so anything else would be unaligned. |
| if (minor_in_tile != 0) { |
| return std::nullopt; |
| } |
| // Make sure the access doesn't span multiple tiles when it shouldn't. |
| // When the minormost dimension matches the minor tile size, moving to the |
| // next sublane always advances the more major dimension. |
| if (consistent_directions && ty_tile_shape[1] != tile_minor && |
| major_in_tile + num_read_sublanes > tile_major) { |
| return std::nullopt; |
| } |
| |
| // Every `packing` rows share the same memory address, so accesses should |
| // be aligned to a multiple of `packing`. |
| // TODO(apaszke): Check that the tile size is a multiple of packing. |
| if (tile_major % packing != 0 || major_in_tile % packing != 0) { |
| return std::nullopt; |
| } |
| |
| int64_t tile_bytes = |
| tile_major * tile_minor * ty.getElementTypeBitWidth() / 8; |
| // Tiles should be at most one chunk in size and there should be a power of |
| // 2 of them in each chunk. |
| if (tile_bytes > target_->ChunkBytes() || |
| tile_bytes % target_->SublaneBytes() != 0 || |
| !llvm::isPowerOf2_64(tile_bytes / target_->SublaneBytes())) { |
| return std::nullopt; |
| } |
| // The displacement is in units of sublanes. |
| int64_t displacement_per_tile = tile_bytes / target_->SublaneBytes(); |
| int64_t tiled_tile_lin_idx = major_among_tiles * tile_strides[0] + |
| minor_among_tiles * tile_strides[1]; |
| int64_t tiled_displacement = |
| tiled_tile_lin_idx * displacement_per_tile + (major_in_tile / packing); |
| |
| auto batch_idx = indices.drop_back(tiled_dims); |
| |
| auto s32_cst = [&](int64_t value) -> Value { |
| return rewriter |
| .create<llo::ConstantOp>(rewriter.getUnknownLoc(), |
| rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(value)) |
| .getResult(); |
| }; |
| SmallVector<int64_t> non_tile_idx; |
| // We constant-fold the displacement computation if all indices are static. |
| if (getConstantIndices(&non_tile_idx, indices.drop_back(tiled_dims))) { |
| int64_t batch_lin_idx = 0; |
| CHECK_EQ(non_tile_idx.size(), batch_strides.size()); |
| for (int i = batch_idx.size() - 1; i >= 0; --i) { |
| batch_lin_idx += non_tile_idx[i] * batch_strides[i]; |
| } |
| int64_t total_displacement = |
| batch_lin_idx * displacement_per_tile + tiled_displacement; |
| if (total_displacement == 0) { |
| return Value(); |
| } |
| return s32_cst(total_displacement); |
| } |
| Location loc = rewriter.getUnknownLoc(); |
| Value batch_lin_idx_scaled = s32_cst(0); |
| CHECK_EQ(batch_idx.size(), batch_strides.size()); |
| for (int i = 0; i < batch_idx.size(); ++i) { |
| CHECK(batch_idx[i].getType().isSignlessInteger(32)); |
| batch_lin_idx_scaled = rewriter.create<llo::ScalarAddS32Op>( |
| loc, batch_lin_idx_scaled, |
| // We fold the multiplication by displacement_per_tile into the |
| // constant here, instead of adding another op at the end. |
| rewriter.create<llo::ScalarMulS32Op>( |
| loc, batch_idx[i], |
| s32_cst(batch_strides[i] * displacement_per_tile))); |
| } |
| return rewriter.create<llo::ScalarAddS32Op>(loc, batch_lin_idx_scaled, |
| s32_cst(tiled_displacement)); |
| } |
| |
| // Only used for scalar (SMEM) loads and stores. |
| // Returns the offset that the indices point to. For packed types (<32-bit) |
| // it additionally returns the sub-element index within that address (or |
| // nullptr if the type is not packed). |
| // TODO(b/299280718): Handle packed 1D memrefs |
| std::pair<Value, Value> indicesToOffset(Location loc, MemRefType ty, |
| ValueRange indices, |
| ConversionPatternRewriter &rewriter) { |
| auto layout = dyn_cast<TiledLayoutAttr>(ty.getLayout()); |
| if (!layout) { |
| return std::make_pair(nullptr, nullptr); |
| } |
| auto tiles = layout.getTiles(); |
| absl::Span<const int64_t> tile = tiles.front().dimensions(); |
| if (tile.size() == 1) { |
| // We can ignore the padding only if there's a single dimension. |
| // We should be able to reuse the code below if we ever need that. |
| if (ty.getRank() > 1 || tiles.size() != 1) { |
| return std::make_pair(nullptr, nullptr); |
| } |
| return std::make_pair(indices.front(), nullptr); |
| } |
| auto int_const = [&](int i) -> Value { |
| return rewriter.create<llo::ConstantOp>(loc, rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(i)); |
| }; |
| auto mod_i = [&](const Value &v, int d) -> Value { |
| return rewriter.create<llo::ScalarBitwiseAndOp>(loc, v, int_const(d - 1)); |
| }; |
| auto div_i = [&](const Value &v, int d) -> Value { |
| return rewriter.create<llo::ScalarDivS32Op>(loc, v, int_const(d)); |
| }; |
| auto mul_i = [&](const Value &v, int d) -> Value { |
| return rewriter.create<llo::ScalarMulS32Op>(loc, v, int_const(d)); |
| }; |
| auto add = [&](const Value &v1, const Value &v2) -> Value { |
| return rewriter.create<llo::ScalarAddS32Op>(loc, v1, v2); |
| }; |
| |
| if (tile.size() == 2) { |
| // Here we assume one smem word is 4 bytes. |
| CHECK_EQ(target_->SmemWordSizeBytes(), 4); |
| unsigned packing = 32 / ty.getElementTypeBitWidth(); |
| if (packing > 1) { |
| if (tiles.size() != 2 || |
| tiles[1].dimensions() != absl::Span<const int64_t>{packing, 1}) { |
| return std::make_pair(nullptr, nullptr); |
| } |
| } |
| int rank = indices.size(); |
| Value tile_offset = int_const(0); |
| auto tile_strides = layout.getTileStrides(); |
| for (int i = rank - 1; i >= 0; --i) { |
| Value index = indices[i]; |
| // Convert element index to tile index. |
| if (i >= rank - 2) { |
| index = div_i(index, tile[i - rank + 2]); |
| } |
| tile_offset = add(tile_offset, mul_i(index, tile_strides[i])); |
| } |
| int64_t tile_bits = tile[0] * tile[1] * ty.getElementTypeBitWidth(); |
| int64_t word_bits = 8 * target_->SmemWordSizeBytes(); |
| CHECK_GE(tile_bits, word_bits); |
| // Get total smem words from tiles. |
| Value word_offset = mul_i(tile_offset, tile_bits / word_bits); |
| // Calculate the extra word offset within the tile when the indices are not |
| // tile-aligned. |
| Value part = packing > 1 ? mod_i(indices[rank - 2], packing) : nullptr; |
| auto major_in_tile = mod_i(indices[rank - 2], tile[0]); |
| auto minor_in_tile = mod_i(indices[rank - 1], tile[1]); |
| // word_offset += (major_in_tile / packing) * tile[1] + minor_in_tile |
| word_offset = add( |
| word_offset, |
| add(mul_i(div_i(major_in_tile, packing), tile[1]), minor_in_tile)); |
| return std::make_pair(word_offset, part); |
| } |
| return std::make_pair(nullptr, nullptr); |
| } |
| |
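| // Encodes a boolean sublane mask as an i32 bitmask; also returns the index |
| // of the last enabled sublane and whether all sublanes are enabled. |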
| std::tuple<uint32_t, int64_t, bool> encodeSublaneMask( |
| ArrayRef<bool> sublane_mask) { |
| // TODO(apaszke): Verify length. |
| uint32_t sublane_mask_i32 = 0; |
| int64_t num_read_sublanes = 0; |
| bool all_sublanes_read = true; |
| for (uint32_t i = 0; i < sublaneCount(); ++i) { |
| if (!sublane_mask[i]) { |
| all_sublanes_read = false; |
| continue; |
| } |
| sublane_mask_i32 |= uint32_t{1} << i; |
| num_read_sublanes = i; |
| } |
| return std::make_tuple(sublane_mask_i32, num_read_sublanes, |
| all_sublanes_read); |
| } |
| |
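| // Lowers vector dialect ops (extract, transfer_read, load, store, |
| // contraction, broadcast, transpose) to LLO. |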
| void populateVectorToLLOConversionPatterns(RewritePatternSet &patterns) { |
| addPattern<vector::ExtractOp>( |
| patterns, |
| [this](vector::ExtractOp op, vector::ExtractOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| if (!isRepresentableVectorType(op.getSourceVectorType()) || |
| op.getSourceVectorType().getElementTypeBitWidth() != 32 || |
| op.hasDynamicPosition()) { |
| return failure(); |
| } |
| for (int64_t pos : op.getStaticPosition()) { |
| if (pos != 0) { |
| return failure(); |
| } |
| } |
| rewriter.replaceOpWithNewOp<llo::VectorToScalarOp>(op, op.getType(), |
| subst.getVector()); |
| return success(); |
| }); |
| addPattern<vector::TransferReadOp>( |
| patterns, |
| [this](vector::TransferReadOp op, vector::TransferReadOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| auto memref_ty = op.getSource().getType(); |
| auto single_sublane_map = AffineMap::get( |
| memref_ty.getRank(), 0, |
| {rewriter.getAffineConstantExpr(0), |
| rewriter.getAffineDimExpr(memref_ty.getRank() - 1)}, |
| rewriter.getContext()); |
| auto regular_map = AffineMap::get( |
| memref_ty.getRank(), 0, |
| {rewriter.getAffineDimExpr(memref_ty.getRank() - 2), |
| rewriter.getAffineDimExpr(memref_ty.getRank() - 1)}, |
| rewriter.getContext()); |
| auto source_ty = getMemRefType(op.getSource()); |
| auto [source_addr, _] = unpackMemRef(subst.getSource(), rewriter); |
| if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) { |
| return failure(); |
| } |
| if (op.getPermutationMap() == single_sublane_map) { |
| std::optional<Value> offset = |
| indicesToVmemOffset(subst.getIndices(), 1, source_ty, rewriter); |
| if (!offset) { |
| return failure(); |
| } |
| auto [source, displacement] = vmemAddrOffsetToAddrDisplacement( |
| rewriter, source_addr, *offset); |
| rewriter.replaceOpWithNewOp<llo::VectorLoadOp>( |
| op, op.getType(), source, |
| /*displacement=*/displacement, |
| /*sublane_mask=*/nullptr, /*sublane_stride=*/0); |
| return success(); |
| } |
| if (op.getPermutationMap() == regular_map) { |
| std::optional<Value> offset = indicesToVmemOffset( |
| subst.getIndices(), sublaneCount(), source_ty, rewriter); |
| if (!offset) { |
| return failure(); |
| } |
| auto [source, displacement] = vmemAddrOffsetToAddrDisplacement( |
| rewriter, source_addr, *offset); |
| rewriter.replaceOpWithNewOp<llo::VectorLoadOp>( |
| op, op.getType(), source, |
| /*displacement=*/displacement, |
| /*sublane_mask=*/nullptr, /*sublane_stride=*/1); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<vector::LoadOp>( |
| patterns, |
| [this](vector::LoadOp op, vector::LoadOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| auto source_ty = getMemRefType(op.getBase()); |
| auto [source_addr, _] = unpackMemRef(subst.getBase(), rewriter); |
| if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) { |
| return failure(); |
| } |
| std::optional<Value> offset = indicesToVmemOffset( |
| subst.getIndices(), sublaneCount(), source_ty, rewriter); |
| if (!offset) { |
| return failure(); |
| } |
| auto [base, displacement] = |
| vmemAddrOffsetToAddrDisplacement(rewriter, source_addr, *offset); |
| rewriter.replaceOpWithNewOp<llo::VectorLoadOp>( |
| op, op.getType(), base, /*displacement=*/displacement, |
| /*sublane_mask=*/nullptr, /*sublane_stride=*/1); |
| return success(); |
| }); |
| addPattern<vector::StoreOp>( |
| patterns, [this](vector::StoreOp op, vector::StoreOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(subst.getValueToStore().getType())) { |
| return failure(); |
| } |
| auto source_ty = getMemRefType(op.getBase()); |
| auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter); |
| if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) { |
| return failure(); |
| } |
| std::optional<Value> offset = indicesToVmemOffset( |
| subst.getIndices(), sublaneCount(), source_ty, rewriter); |
| if (!offset) { |
| return failure(); |
| } |
| auto [base, displacement] = |
| vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset); |
| rewriter.replaceOpWithNewOp<llo::VectorStoreOp>( |
| op, /*address=*/base, /*displacement=*/displacement, |
| /*to_store=*/subst.getValueToStore(), /*sublane_mask=*/nullptr, |
| /*sublane_stride=*/1); |
| return success(); |
| }); |
| addPattern<vector::ContractionOp>( |
| patterns, |
| [this](vector::ContractionOp op, vector::ContractionOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| // This pattern is only meant to support a small class of matmuls. |
| if (subst.getKind() != vector::CombiningKind::ADD) { |
| return failure(); |
| } |
| auto ctx = rewriter.getContext(); |
| auto matmul_iterator_types = rewriter.getArrayAttr({ |
| vector::IteratorTypeAttr::get(ctx, |
| vector::IteratorType::parallel), |
| vector::IteratorTypeAttr::get(ctx, |
| vector::IteratorType::parallel), |
| vector::IteratorTypeAttr::get(ctx, |
| vector::IteratorType::reduction), |
| }); |
| if (subst.getIteratorTypes() != matmul_iterator_types) { |
| return failure(); |
| } |
| auto maps = subst.getIndexingMaps().getValue(); |
| auto lhs_map = mlir::AffineMapAttr::get(mlir::AffineMap::get( |
| 3, 0, |
| {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)}, |
| rewriter.getContext())); |
| auto rhs_map = mlir::AffineMapAttr::get(mlir::AffineMap::get( |
| 3, 0, |
| {rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(1)}, |
| rewriter.getContext())); |
| auto rhs_transp_map = mlir::AffineMapAttr::get(mlir::AffineMap::get( |
| 3, 0, |
| {rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(2)}, |
| rewriter.getContext())); |
| auto out_map = mlir::AffineMapAttr::get(mlir::AffineMap::get( |
| 3, 0, |
| {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, |
| rewriter.getContext())); |
| if (maps.size() != 3) { |
| return failure(); |
| } |
| if (maps[0] != lhs_map || |
| (maps[1] != rhs_map && maps[1] != rhs_transp_map) || |
| maps[2] != out_map) { |
| return failure(); |
| } |
| bool transposed = maps[1] == rhs_transp_map; |
| auto precision_attr = dyn_cast_if_present<tpu::ContractPrecisionAttr>( |
| op->getAttr("precision")); |
| rewriter.replaceOpWithNewOp<tpu::MatmulOp>( |
| op, op.getType(), op.getLhs(), op.getRhs(), op.getAcc(), |
| rewriter.getBoolAttr(false), rewriter.getBoolAttr(transposed), |
| precision_attr); |
| return success(); |
| }); |
| addPattern<vector::BroadcastOp>( |
| patterns, |
| [this](vector::BroadcastOp op, vector::BroadcastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| if (!op.getSourceType().isSignlessInteger(32) && |
| !op.getSourceType().isF32()) { |
| return failure(); |
| } |
| rewriter.replaceOpWithNewOp<llo::ScalarToVectorOp>(op, op.getType(), |
| subst.getSource()); |
| return success(); |
| }); |
| addPattern<vector::TransposeOp>( |
| patterns, |
| [this](vector::TransposeOp op, vector::TransposeOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| int64_t transpose_width = target_->LaneCount(); |
| auto src_ty = op.getSourceVectorType(); |
| auto dst_ty = op.getResultVectorType(); |
| ArrayRef<int64_t> permutation = op.getPermutation(); |
| if (permutation.size() != 2 || permutation[0] != 1 || |
| permutation[1] != 0) { |
| return failure(); |
| } |
| int batching_factor; |
| if (src_ty == dst_ty && |
| src_ty.getShape() == |
| ArrayRef<int64_t>{transpose_width, transpose_width}) { |
| batching_factor = 1; // No batching. |
| } else if (src_ty.getElementType() == dst_ty.getElementType() && |
| src_ty.getElementTypeBitWidth() == 16 && |
| src_ty.getShape() == |
| ArrayRef<int64_t>{transpose_width, |
| transpose_width * 2} && |
| dst_ty.getShape() == ArrayRef<int64_t>{transpose_width * 2, |
| transpose_width}) { |
| batching_factor = 2; |
| if (!target_->SupportsVsupp()) { |
| return failure(); |
| } |
| } else { |
| return failure(); |
| } |
| ValueRange src_vregs; |
| if (auto roll_op = |
| op.getVector().getDefiningOp<tpu::RollVectorsOp>()) { |
| src_vregs = roll_op.getOperands(); |
| } else { |
| return failure(); |
| } |
| auto packing = |
| target_->VectorScalarBitWidth() / src_ty.getElementTypeBitWidth(); |
| if (src_vregs.size() != transpose_width * batching_factor / |
| (target_->SublaneCount() * packing)) { |
| return failure(); |
| } |
| auto width = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(128)); |
| std::vector<Value> result_vregs; |
| if (batching_factor == 1) { |
| llo::VxposeMode mode; |
| if (src_ty.getElementTypeBitWidth() == 32) { |
| mode = llo::VxposeMode::kB32; |
| } else if (src_ty.getElementTypeBitWidth() == 16) { |
| mode = llo::VxposeMode::kCompressedB16; |
| } else { |
| return failure(); |
| } |
| for (int32_t i = 0; i < src_vregs.size(); ++i) { |
| rewriter.create<llo::VectorTransposeOp>( |
| op.getLoc(), src_vregs[i], mode, width, i, src_vregs.size(), |
| /*xlu_id=*/0, /*source_bus=*/nullptr); |
| } |
| result_vregs.reserve(src_vregs.size()); |
| for (int32_t i = 0; i < src_vregs.size(); ++i) { |
| result_vregs.push_back( |
| rewriter.create<llo::VectorTransposeResultOp>( |
| op.getLoc(), src_vregs[0].getType())); |
| } |
| } else if (batching_factor == 2) { |
| CHECK_EQ(src_vregs.size() % 2, 0); |
| CHECK_EQ(src_ty.getElementTypeBitWidth(), 16); |
| for (int32_t i = 0; i < src_vregs.size() / 2; ++i) { |
| rewriter.create<llo::VectorTransposeBinaryCompressedB16Op>( |
| op.getLoc(), src_vregs[i * 2], src_vregs[i * 2 + 1], width, i, |
| src_vregs.size(), /*xlu_id=*/0, /*source_bus=*/nullptr); |
| } |
| result_vregs = std::vector<Value>(src_vregs.size(), nullptr); |
| for (int32_t i = 0; i < src_vregs.size() / 2; ++i) { |
| result_vregs[i] = rewriter.create<llo::VectorTransposeResultOp>( |
| op.getLoc(), src_vregs[0].getType()); |
| result_vregs[i + src_vregs.size() / 2] = |
| rewriter.create<llo::VectorTransposeResultOp>( |
| op.getLoc(), src_vregs[0].getType()); |
| } |
| } else { |
| LOG(FATAL) << "Unrecognized batching factor"; |
| } |
| rewriter.replaceOpWithNewOp<tpu::RollVectorsOp>(op, op.getType(), |
| result_vregs); |
| return success(); |
| }); |
| } |
| |
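| // Builds a pattern lowering an elementwise vector op with `arity` operands |
| // to `LLOOpType`, provided every operand is a representable vreg or mask |
| // vector whose element type passes `elem_type_filter`. |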
| template <typename Op, typename LLOOpType, int arity, |
| bool (*elem_type_filter)(Type ty)> |
| PatternFunc<Op> VectorElementwisePattern() { |
| return [this](Op op, typename Op::Adaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (op->getNumOperands() != arity) { |
| return failure(); |
| } |
| for (Value operand : op->getOperands()) { |
| if (!isRepresentableVectorType(operand.getType()) && |
| !isMaskVectorType(operand.getType())) { |
| return failure(); |
| } |
| if (!elem_type_filter( |
| cast<VectorType>(operand.getType()).getElementType())) { |
| return failure(); |
| } |
| } |
| rewriter.replaceOpWithNewOp<LLOOpType>(op, op.getType(), |
| subst.getOperands()); |
| return success(); |
| }; |
| } |
| |
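| // Scalar counterpart of VectorElementwisePattern: every operand type must |
| // pass `elem_type_filter`; the result takes the type of the first converted |
| // operand. |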
| template <typename Op, typename LLOOpType, int arity, |
| bool (*elem_type_filter)(Type ty)> |
| PatternFunc<Op> ScalarElementwisePattern() { |
| return [](Op op, typename Op::Adaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (op->getNumOperands() != arity) { |
| return failure(); |
| } |
| for (Value operand : op->getOperands()) { |
| if (!elem_type_filter(operand.getType())) { |
| return failure(); |
| } |
| } |
| auto ty = subst.getOperands().front().getType(); |
| rewriter.replaceOpWithNewOp<LLOOpType>(op, ty, subst.getOperands()); |
| return success(); |
| }; |
| } |
| |
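| // Lowers arith ops (and a few scalar math ops) to LLO. |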
| void populateArithToLLOConversionPatterns(RewritePatternSet &patterns) { |
| constexpr bool (*f32)(Type) = [](Type ty) { return ty.isF32(); }; |
| constexpr bool (*i32)(Type) = [](Type ty) { |
| return ty.isSignlessInteger(32) || ty.isIndex(); |
| }; |
| constexpr bool (*i1)(Type) = [](Type ty) { |
| return ty.isSignlessInteger(1); |
| }; |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::AddFOp, llo::VectorAddF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::SubFOp, llo::VectorSubF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::MulFOp, llo::VectorMulF32Op, 2, f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::MaximumFOp, llo::VectorMaxF32Op, |
| 2, f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::MinimumFOp, llo::VectorMinF32Op, |
| 2, f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::MaxSIOp, llo::VectorMaxS32Op, 2, |
| i32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::MinSIOp, llo::VectorMinS32Op, 2, |
| i32>()); |
| addPattern<arith::DivFOp>( |
| patterns, [this](arith::DivFOp op, arith::DivFOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| if (!cast<VectorType>(op.getType()).getElementType().isF32()) { |
| return failure(); |
| } |
| auto rhs_inv = rewriter.create<llo::VectorRecipF32Op>(op.getLoc(), |
| subst.getRhs()); |
| rewriter.replaceOpWithNewOp<llo::VectorMulF32Op>(op, subst.getLhs(), |
| rhs_inv); |
| return success(); |
| }); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::AndIOp, llo::VectorMaskAndOp, 2, i1>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::OrIOp, llo::VectorMaskOrOp, 2, i1>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::XOrIOp, llo::VectorMaskXorOp, 2, i1>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::NegFOp, llo::VectorNegF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::AddIOp, llo::VectorAddS32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::SubIOp, llo::VectorSubS32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::MulIOp, llo::VectorMulS32Op, 2, i32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::DivSIOp, llo::VectorDivS32Op, 2, |
| i32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<arith::RemSIOp, llo::VectorRemS32Op, 2, |
| i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::ShLIOp, llo::VectorShiftLeftLogicalOp, |
| 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::ShRUIOp, llo::VectorShiftRightLogicalOp, |
| 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::ShRSIOp, |
| llo::VectorShiftRightArithmeticOp, 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::AndIOp, llo::VectorAndU32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::OrIOp, llo::VectorOrU32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<arith::XOrIOp, llo::VectorXOrU32Op, 2, i32>()); |
| addPattern<arith::XOrIOp>( |
| patterns, [](arith::XOrIOp op, arith::XOrIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| // Lower arith.xori of i1 inputs. |
| for (Value operand : op.getOperands()) { |
| if (!operand.getType().isSignlessInteger(1)) { |
| return failure(); |
| } |
| } |
| |
| // arith.xori lhs, -1 |
| // is lowered to: |
| // llo.pnot lhs |
| |
| // MLIR canonicalization passes (run before this stage) move constants |
| // of commutative operations to the rhs |
| // (https://mlir.llvm.org/docs/Canonicalization/#globally-applied-rules), |
| // so we assume that the constant -1 is on the rhs. |
| bool is_not = false; |
| Value op_rhs = op.getOperands().back(); |
| auto crhs = op_rhs.getDefiningOp<arith::ConstantOp>(); |
| if (crhs != nullptr) { |
| is_not = cast<mlir::IntegerAttr>(crhs.getValue()).getInt() == -1; |
| } |
| |
| if (is_not) { |
| rewriter.replaceOpWithNewOp<llo::PredicateNegateOp>( |
| op, op.getType(), subst.getLhs()); |
| return success(); |
| } |
| |
| Value lhs = subst.getLhs(); |
| Value rhs = subst.getRhs(); |
| |
| // Implement XOR between lhs and rhs as follows: |
| // |
| // lhs_and_not_rhs = lhs & !rhs |
| // not_lhs_and_rhs = !lhs & rhs |
| // xor = lhs_and_not_rhs | not_lhs_and_rhs |
| |
| auto not_lhs = |
| rewriter.create<llo::PredicateNegateOp>(op.getLoc(), lhs); |
| auto not_rhs = |
| rewriter.create<llo::PredicateNegateOp>(op.getLoc(), rhs); |
| auto lhs_and_not_rhs = |
| rewriter.create<llo::PredicateAndOp>(op.getLoc(), lhs, not_rhs); |
| auto not_lhs_and_rhs = |
| rewriter.create<llo::PredicateAndOp>(op.getLoc(), not_lhs, rhs); |
| |
| rewriter.replaceOpWithNewOp<llo::PredicateOrOp>( |
| op, op.getType(), lhs_and_not_rhs, not_lhs_and_rhs); |
| return success(); |
| }); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::NegFOp, llo::ScalarNegF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<math::SqrtOp, llo::ScalarSqrtF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<math::ExpOp, llo::ScalarExpF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<math::LogOp, llo::ScalarLogF32Op, 1, f32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<math::RoundOp, llo::ScalarRoundF32Op, 1, |
| f32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<math::RoundEvenOp, |
| llo::ScalarRoundEvenF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::AddFOp, llo::ScalarAddF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::AddIOp, llo::ScalarAddS32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::SubFOp, llo::ScalarSubF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::SubIOp, llo::ScalarSubS32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::MulFOp, llo::ScalarMulF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::MulIOp, llo::ScalarMulS32Op, 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::ShLIOp, llo::ScalarShllOp, 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::ShRUIOp, llo::ScalarShrlOp, 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::ShRSIOp, llo::ScalarShraOp, 2, i32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::OrIOp, llo::ScalarBitwiseOrOp, 2, |
| i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::OrIOp, llo::PredicateOrOp, 2, i1>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::AndIOp, llo::PredicateAndOp, 2, i1>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::AndIOp, llo::ScalarBitwiseAndOp, |
| 2, i32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<arith::DivFOp, llo::ScalarDivF32Op, 2, f32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::DivSIOp, llo::ScalarDivS32Op, 2, |
| i32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::RemSIOp, llo::ScalarRemS32Op, 2, |
| i32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::MaximumFOp, llo::ScalarMaxF32Op, |
| 2, f32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::MinimumFOp, llo::ScalarMinF32Op, |
| 2, f32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::MaxSIOp, llo::ScalarMaxS32Op, 2, |
| i32>()); |
| addPattern(patterns, |
| ScalarElementwisePattern<arith::MinSIOp, llo::ScalarMinS32Op, 2, |
| i32>()); |
| addPattern<arith::ConstantOp>( |
| patterns, [this](arith::ConstantOp op, arith::ConstantOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto attr = op.getValue(); |
| if (attr.getType().isUnsignedInteger(32) || |
| attr.getType().isSignlessInteger(32) || |
| attr.getType().isSignlessInteger(1) || attr.getType().isF32()) { |
| goto supported; |
| } |
| if (auto splat = dyn_cast<SplatElementsAttr>(attr)) { |
| if (isRepresentableVectorType(op.getType()) || |
| isMaskVectorType(op.getType())) { |
| goto supported; |
| } |
| } |
| if (attr.getType().isIndex()) { |
| rewriter.replaceOpWithNewOp<llo::ConstantOp>( |
| op, rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr( |
| cast<mlir::IntegerAttr>(attr).getInt())); |
| return success(); |
| } |
| // Constants that are not representable in LLO should become dead. |
| rewriter.eraseOp(op); |
| return success(); |
| supported: |
| rewriter.replaceOpWithNewOp<llo::ConstantOp>(op, op.getType(), attr); |
| return success(); |
| }); |
| addPattern<arith::IndexCastOp>( |
| patterns, [](arith::IndexCastOp op, arith::IndexCastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto in_ty = op.getIn().getType(); |
| auto out_ty = op.getType(); |
| if ((in_ty.isIndex() && out_ty.isSignlessInteger(32)) || |
| (in_ty.isSignlessInteger(32) && out_ty.isIndex())) { |
| rewriter.replaceOp(op, subst.getIn()); |
| return success(); |
| } |
| op.emitOpError("Unsupported cast"); |
| return failure(); |
| }); |
| addPattern<arith::CmpFOp>( |
| patterns, [this](arith::CmpFOp op, arith::CmpFOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto arg_ty = op.getOperandTypes().front(); |
| if (auto vec_ty = dyn_cast<VectorType>(arg_ty)) { |
| if (!isRepresentableVectorType(vec_ty) || |
| !isMaskVectorType(op.getType()) || |
| !vec_ty.getElementType().isF32()) { |
| return failure(); |
| } |
| switch (op.getPredicate()) { |
| case mlir::arith::CmpFPredicate::OEQ: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpEqF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::ONE: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpNeF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OLT: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpLtF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OLE: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpLeF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OGT: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpGtF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OGE: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpGeF32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| default: |
| return failure(); |
| } |
| } else { |
| if (!arg_ty.isF32()) { |
| return failure(); |
| } |
| switch (op.getPredicate()) { |
| case mlir::arith::CmpFPredicate::OEQ: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpEqF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::ONE: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpNeF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OLT: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpLtF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OLE: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpLeF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OGT: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpGtF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpFPredicate::OGE: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpGeF32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| default: |
| return failure(); |
| } |
| } |
| return success(); |
| }); |
| addPattern<arith::CmpIOp>( |
| patterns, [this](arith::CmpIOp op, arith::CmpIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto arg_ty = op.getOperandTypes().front(); |
| if (auto vec_ty = dyn_cast<VectorType>(arg_ty)) { |
| if (!isRepresentableVectorType(vec_ty) || |
| !isMaskVectorType(op.getType()) || |
| !vec_ty.getElementType().isSignlessInteger(32)) { |
| return failure(); |
| } |
| switch (op.getPredicate()) { |
| case mlir::arith::CmpIPredicate::eq: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpEqS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::ne: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpNeS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::slt: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpLtS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sle: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpLeS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sgt: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpGtS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sge: |
| rewriter.replaceOpWithNewOp<llo::VectorCmpGeS32Op>( |
| op, op.getType(), subst.getLhs(), subst.getRhs()); |
| break; |
| default: |
| return failure(); |
| } |
| } else { |
| if (!arg_ty.isSignlessInteger(32) && !arg_ty.isIndex()) { |
| return failure(); |
| } |
| switch (op.getPredicate()) { |
| case mlir::arith::CmpIPredicate::eq: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpEqS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::ne: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpNeS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::slt: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpLtS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sle: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpLeS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sgt: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpGtS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| case mlir::arith::CmpIPredicate::sge: |
| rewriter.replaceOpWithNewOp<llo::ScalarCmpGeS32Op>( |
| op, subst.getLhs(), subst.getRhs()); |
| break; |
| default: |
| return failure(); |
| } |
| } |
| return success(); |
| }); |
| addPattern<arith::SIToFPOp>( |
| patterns, [this](arith::SIToFPOp op, arith::SIToFPOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto arg_ty = subst.getIn().getType(); |
| if (VectorType in_ty = dyn_cast<VectorType>(arg_ty)) { |
| VectorType out_ty = dyn_cast<VectorType>(op.getType()); |
| if (!in_ty || !out_ty || !isRepresentableVectorType(in_ty) || |
| !isRepresentableVectorType(out_ty)) { |
| return failure(); |
| } |
| if (in_ty.getElementType() == rewriter.getI32Type() && |
| out_ty.getElementType() == rewriter.getF32Type()) { |
| rewriter.replaceOpWithNewOp<llo::VectorConvertS32ToF32Op>( |
| op, out_ty, subst.getIn()); |
| return success(); |
| } |
| } else if (arg_ty == rewriter.getI32Type() && |
| op.getType() == rewriter.getF32Type()) { |
| rewriter.replaceOpWithNewOp<llo::ScalarConvertS32ToF32Op>( |
| op, op.getType(), subst.getIn()); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<arith::FPToSIOp>( |
| patterns, [this](arith::FPToSIOp op, arith::FPToSIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto arg_ty = subst.getIn().getType(); |
| if (VectorType in_ty = dyn_cast<VectorType>(arg_ty)) { |
| VectorType out_ty = dyn_cast<VectorType>(op.getType()); |
| if (!in_ty || !out_ty || !isRepresentableVectorType(in_ty) || |
| !isRepresentableVectorType(out_ty)) { |
| return failure(); |
| } |
| if (in_ty.getElementType() == rewriter.getF32Type() && |
| out_ty.getElementType() == rewriter.getI32Type()) { |
| rewriter.replaceOpWithNewOp< |
| llo::VectorConvertF32ToS32TowardsZeroPseudoOp>(op, out_ty, |
| subst.getIn()); |
| return success(); |
| } |
| } else if (arg_ty == rewriter.getF32Type() && |
| op.getType() == rewriter.getI32Type()) { |
| rewriter.replaceOpWithNewOp< |
| llo::ScalarConvertF32ToS32TowardsZeroPseudoOp>(op, op.getType(), |
| subst.getIn()); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<arith::SelectOp>( |
| patterns, [this](arith::SelectOp op, arith::SelectOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (auto ty = dyn_cast<VectorType>(op.getType())) { |
| if (!isRepresentableVectorType(ty)) { |
| return failure(); |
| } |
| if (!isa<VectorType>(subst.getCondition().getType())) { |
| return failure(); |
| } |
| rewriter.replaceOpWithNewOp<llo::VectorSelectOp>( |
| op, op.getType(), subst.getCondition(), subst.getTrueValue(), |
| subst.getFalseValue()); |
| return success(); |
| } |
| if (op.getType().isSignlessInteger(32) || op.getType().isIndex() || |
| op.getType().isF32()) { |
| rewriter.replaceOpWithNewOp<llo::ScalarSelectOp>( |
| op, subst.getTrueValue().getType(), subst.getCondition(), |
| subst.getTrueValue(), subst.getFalseValue()); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<arith::ExtSIOp>( |
| patterns, [](arith::ExtSIOp op, arith::ExtSIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| // Narrow integer types are already converted to 32-bit integers. |
| if (!op.getType().isSignlessInteger(32) || |
| !subst.getIn().getType().isSignlessInteger(32)) { |
| return failure(); |
| } |
| rewriter.replaceOp(op, subst.getIn()); |
| return success(); |
| }); |
| addPattern<arith::TruncIOp>( |
| patterns, [](arith::TruncIOp op, arith::TruncIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!op.getIn().getType().isSignlessInteger(32)) { |
| return failure(); |
| } |
| unsigned bitwidth = op.getType().getIntOrFloatBitWidth(); |
| rewriter.replaceOpWithNewOp<llo::ScalarBitwiseAndOp>( |
| op, subst.getIn(), |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr((1 << bitwidth) - 1))); |
| return success(); |
| }); |
| addPattern<arith::ExtUIOp>( |
| patterns, [](arith::ExtUIOp op, arith::ExtUIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| Type in_ty = subst.getIn().getType(); |
| if (in_ty == op.getType()) { |
| // The type of the new value is already equal to the destination |
| // type, so there is nothing to do. |
| rewriter.replaceOp(op, subst.getIn()); |
| return success(); |
| } |
| |
| if (auto in_vec_ty = dyn_cast<VectorType>(in_ty)) { |
| auto out_vec_ty = cast<VectorType>(op.getType()); |
| if (in_vec_ty.getElementTypeBitWidth() != 1 || |
| out_vec_ty.getElementTypeBitWidth() != 32) { |
| return failure(); |
| } |
| auto make_splat = [&](int64_t value) -> Value { |
| return rewriter.create<llo::ConstantOp>( |
| op.getLoc(), out_vec_ty, |
| SplatElementsAttr::get(out_vec_ty, |
| rewriter.getI32IntegerAttr(value))); |
| }; |
| rewriter.replaceOpWithNewOp<llo::VectorSelectOp>( |
| op, out_vec_ty, subst.getIn(), make_splat(1), make_splat(0)); |
| return success(); |
| } |
| if (in_ty.isSignlessInteger(1) && |
| op.getType().isSignlessInteger(32)) { |
| auto make_const = [&](int64_t value) -> Value { |
| return rewriter.create<llo::ConstantOp>( |
| op.getLoc(), op.getType(), rewriter.getI32IntegerAttr(value)); |
| }; |
| rewriter.replaceOpWithNewOp<llo::ScalarSelectOp>( |
| op, op.getType(), subst.getIn(), make_const(1), make_const(0)); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<arith::BitcastOp>( |
| patterns, [](arith::BitcastOp op, arith::BitcastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto in_ty = op.getIn().getType(); |
| auto out_ty = op.getType(); |
| auto in_vty = dyn_cast<VectorType>(in_ty); |
| auto out_vty = dyn_cast<VectorType>(out_ty); |
| if (in_vty && out_vty && |
| in_vty.getElementTypeBitWidth() == |
| out_vty.getElementTypeBitWidth()) { |
| rewriter.replaceOpWithNewOp<llo::VectorBitcastOp>(op, out_ty, |
| subst.getIn()); |
| return success(); |
| } |
| if (in_ty.getIntOrFloatBitWidth() == out_ty.getIntOrFloatBitWidth()) { |
| // We only allow 32-bit types in LLO and any narrower type is |
| // represented by i32, with the valid data being in the low bits. |
| // So, for casts between 32-bit types we need to do the bitcast. |
| // But, for casts between narrower types, we shouldn't do anything. |
| if (in_ty.getIntOrFloatBitWidth() == 32) { |
| rewriter.replaceOpWithNewOp<llo::ScalarBitcastOp>(op, out_ty, |
| subst.getIn()); |
| } else { |
| rewriter.replaceOp(op, subst.getIn()); |
| } |
| return success(); |
| } |
| op.emitOpError("Failed to arith::bitcast") |
| << in_ty << " to " << out_ty; |
| return failure(); |
| }); |
| } |
| |
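| // Lowers math dialect ops to LLO. |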
| void populateMathToLLOConversionPatterns(RewritePatternSet &patterns) { |
| constexpr bool (*f32)(Type) = [](Type ty) { return ty.isF32(); }; |
| constexpr bool (*i32)(Type) = [](Type ty) { |
| return ty.isSignlessInteger(32) || ty.isIndex(); |
| }; |
| addPattern(patterns, |
| VectorElementwisePattern<math::RsqrtOp, llo::VectorRsqrtF32Op, 1, |
| f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::SqrtOp, llo::VectorSqrtF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| ScalarElementwisePattern<math::CountLeadingZerosOp, |
| llo::ScalarCountLeadingZerosOp, 1, i32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::ExpOp, llo::VectorExpF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::Exp2Op, llo::VectorPow2F32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::CosOp, llo::VectorCosF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::SinOp, llo::VectorSinF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::TanhOp, llo::VectorTanhF32Op, 1, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::LogOp, llo::VectorLogF32Op, 1, f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<math::Log1pOp, llo::VectorLog1pF32Op, 1, |
| f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::PowFOp, llo::VectorPowF32Op, 2, f32>()); |
| addPattern( |
| patterns, |
| VectorElementwisePattern<math::AbsFOp, llo::VectorAbsF32Op, 1, f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<math::RoundOp, llo::VectorRoundF32Op, 1, |
| f32>()); |
| addPattern(patterns, |
| VectorElementwisePattern<math::RoundEvenOp, |
| llo::VectorRoundEvenF32Op, 1, f32>()); |
| addPattern<math::AbsIOp>( |
| patterns, [this](math::AbsIOp op, math::AbsIOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| Value operand = subst.getOperand(); |
| Type operandType = operand.getType(); |
| |
| if (!isRepresentableVectorType(operandType)) { |
| return failure(); |
| } |
| |
| VectorType operandVectorType = cast<VectorType>(operandType); |
| if (operandVectorType.getElementType() != rewriter.getI32Type()) { |
| return failure(); |
| } |
| |
| // Build the following instruction sequence: |
| // cmp = vcmp.lt.32 input, 0 |
| // neg = vneg.32 input |
| // out = vselect cmp, neg, input |
| |
| Value zero = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), operandVectorType, |
| SplatElementsAttr::get(operandVectorType, |
| rewriter.getI32IntegerAttr(0))); |
| VectorType cmpVectorType = VectorType::get( |
| operandVectorType.getShape(), rewriter.getI1Type()); |
| Value cmp = rewriter.create<llo::VectorCmpLtS32Op>( |
| op.getLoc(), cmpVectorType, operand, zero); |
| Value neg = |
| rewriter.create<llo::VectorNegS32Op>(op.getLoc(), operand); |
| rewriter.replaceOpWithNewOp<llo::VectorSelectOp>( |
| op, operandVectorType, cmp, neg, operand); |
| return success(); |
| }); |
| } |
| |
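| // Lowers memref.load/store for SMEM refs and memref.alloca for VMEM/SMEM |
| // refs to LLO. |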
| void populateMemrefToLLOConversionPatterns(RewritePatternSet &patterns) { |
| addPattern<memref::LoadOp>( |
| patterns, [this](memref::LoadOp op, memref::LoadOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto ty = getMemRefType(op.getMemRef()); |
| auto [ref_addr, _] = unpackMemRef(subst.getMemref(), rewriter); |
| if (!hasMemorySpace(op.getMemRefType(), tpu::MemorySpace::smem)) { |
| return failure(); |
| } |
| if (!op.getType().isIntOrFloat() || |
| op.getType().getIntOrFloatBitWidth() > 32) { |
| return failure(); |
| } |
| auto [offset, part] = |
| indicesToOffset(op.getLoc(), ty, subst.getIndices(), rewriter); |
| if (!offset) { |
| return failure(); |
| } |
| Type new_type; |
| if (op.getType().isSignlessInteger()) { |
| new_type = rewriter.getI32Type(); |
| } else if (isa<FloatType>(op.getType())) { |
| new_type = rewriter.getF32Type(); |
| } else { |
| return failure(); |
| } |
| Value result = rewriter.create<llo::ScalarLoadOp>( |
| op.getLoc(), new_type, ref_addr, offset); |
| if (part != nullptr) { |
| unsigned bitwidth = op.getType().getIntOrFloatBitWidth(); |
| unsigned base_mask = (1 << bitwidth) - 1; |
| Value mask = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(base_mask)); |
| Value shift = rewriter.create<llo::ScalarMulS32Op>( |
| op.getLoc(), part, |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(bitwidth))); |
| // Bitcast result to I32 type for shifting. |
| result = rewriter.create<llo::ScalarBitcastOp>( |
| op.getLoc(), rewriter.getI32Type(), result); |
| // Shift and mask the subelement we want. |
| result = rewriter.create<llo::ScalarBitwiseAndOp>( |
| op.getLoc(), mask, |
| rewriter.create<llo::ScalarShrlOp>(op.getLoc(), result, shift)); |
| } |
| rewriter.replaceOp(op, result); |
| return success(); |
| }); |
| addPattern<memref::StoreOp>( |
| patterns, [this](memref::StoreOp op, memref::StoreOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto ty = getMemRefType(op.getMemRef()); |
| auto [ref_addr, _] = unpackMemRef(subst.getMemref(), rewriter); |
| if (!hasMemorySpace(op.getMemRefType(), tpu::MemorySpace::smem)) { |
| return failure(); |
| } |
| if (!op.getValue().getType().isIntOrFloat() || |
| op.getValue().getType().getIntOrFloatBitWidth() > 32) { |
| return failure(); |
| } |
|           auto [offset, part] = |
|               indicesToOffset(op.getLoc(), ty, subst.getIndices(), rewriter); |
|           if (!offset) { |
|             return failure(); |
|           } |
|           Value addr = rewriter.create<llo::ScalarAddressSmemOp>( |
|               op.getLoc(), ref_addr, offset); |
| Value to_store = subst.getValue(); |
| if (part != nullptr) { |
| Value zero = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(0)); |
| unsigned bitwidth = op.getValue().getType().getIntOrFloatBitWidth(); |
| unsigned base_mask = (1 << bitwidth) - 1; |
| Value update_mask = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(base_mask)); |
| Value shift = rewriter.create<llo::ScalarMulS32Op>( |
| op.getLoc(), part, |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(bitwidth))); |
|             // Invert the shifted update mask to obtain a mask of the bits |
|             // we want to keep. |
| Value existing_mask = rewriter.create<llo::ScalarBitwiseXorOp>( |
| op.getLoc(), |
| // Shift the mask to where the update will happen. |
| rewriter.create<llo::ScalarShllOp>(op.getLoc(), update_mask, |
| shift), |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(0xFFFFFFFF))); |
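|             // E.g. (hypothetical 16-bit store with part == 1): |
|             // update_mask == 0x0000FFFF, the shifted mask is 0xFFFF0000, so |
|             // existing_mask == 0x0000FFFF preserves the untouched half-word. |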
| // Bitcast to_store to I32 type for shifting. |
| to_store = rewriter.create<llo::ScalarBitcastOp>( |
| op.getLoc(), rewriter.getI32Type(), to_store); |
| // Blend the shifted update with existing values with the updated |
| // location masked out. |
| to_store = rewriter.create<llo::ScalarBitwiseOrOp>( |
| op.getLoc(), |
| rewriter.create<llo::ScalarShllOp>(op.getLoc(), to_store, |
| shift), |
| rewriter.create<llo::ScalarBitwiseAndOp>( |
| op.getLoc(), existing_mask, |
| rewriter.create<llo::ScalarLoadOp>( |
| op.getLoc(), to_store.getType(), addr, zero))); |
| } |
| rewriter.replaceOpWithNewOp<llo::ScalarStoreOp>(op, addr, to_store); |
| return success(); |
| }); |
| addPattern<memref::AllocaOp>( |
| patterns, [this](memref::AllocaOp op, memref::AllocaOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto ref_ty = op.getType(); |
| auto size_bytes = getMemRefSizeInBytes(ref_ty); |
| if (hasMemorySpace(ref_ty, tpu::MemorySpace::vmem)) { |
| if (!size_bytes || |
| *size_bytes % target_->VmemWordSizeBytes() != 0) { |
| return failure(); |
| } |
| rewriter.replaceOpWithNewOp<llo::AllocaVmemOp>( |
| op, rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(*size_bytes / |
| target_->VmemWordSizeBytes())); |
| return success(); |
| } |
| if (hasMemorySpace(ref_ty, tpu::MemorySpace::smem)) { |
| if (!size_bytes || |
| *size_bytes % target_->SmemWordSizeBytes() != 0) { |
| return failure(); |
| } |
| rewriter.replaceOpWithNewOp<llo::AllocaSmemOp>( |
| op, rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(*size_bytes / |
| target_->SmemWordSizeBytes())); |
| return success(); |
| } |
| // TODO: Support semaphore allocations |
| op.emitOpError( |
| "Cannot allocate ref in non-VMEM/SMEM memory space using " |
| "memref.alloca."); |
| return failure(); |
| }); |
| } |
| |
| void populateCFToLLOConversionPatterns(RewritePatternSet &patterns) { |
| addPattern<cf::AssertOp>(patterns, [](cf::AssertOp op, |
| cf::AssertOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::ErrorIfOp>( |
| op, |
| rewriter.create<llo::PredicateNegateOp>(op.getLoc(), subst.getArg()), |
| op.getMsgAttr()); |
| return success(); |
| }); |
| } |
| |
| void populateTPUToLLOConversionPatterns(RewritePatternSet &patterns) { |
| addPattern<tpu::CreateMaskOp>( |
| patterns, [this](tpu::CreateMaskOp op, tpu::CreateMaskOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (op.getHigh().size() != 2) { |
| return failure(); |
| } |
| if (!isMaskVectorType(op.getType())) { |
| return failure(); |
| } |
| auto get_constant = [](Value v) -> std::optional<int64_t> { |
| auto cst_op = v.getDefiningOp<arith::ConstantOp>(); |
| if (!cst_op) { |
| return std::nullopt; |
| } |
| return cast<IntegerAttr>(cst_op.getValue()).getInt(); |
| }; |
| auto sublane_low = get_constant(op.getLow()[0]); |
| auto sublane_high = get_constant(op.getHigh()[0]); |
| auto lane_low = get_constant(op.getLow()[1]); |
| auto lane_high = get_constant(op.getHigh()[1]); |
| if (!sublane_low || !sublane_high || !lane_low || !lane_high) { |
| return failure(); |
| } |
| if (*sublane_low < 0 || *sublane_low > sublaneCount() || |
| *sublane_high < 0 || *sublane_high > sublaneCount() || |
| *lane_low < 0 || *lane_low > laneCount() || *lane_high < 0 || |
| *lane_high > laneCount()) { |
| return failure(); |
| } |
| if (*sublane_low > *sublane_high || *lane_low > *lane_high) { |
| return failure(); |
| } |
| if (*sublane_low == *sublane_high || *lane_low == *lane_high) { |
| rewriter.replaceOpWithNewOp<llo::ConstantOp>( |
| op, op.getType(), |
| SplatElementsAttr::get(cast<ShapedType>(op.getType()), false)); |
| return success(); |
| } |
| // LLO create mask has inclusive upper bounds. |
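|           // E.g. an exclusive lane_high of 8 becomes lane_end == 7 below. |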
| rewriter.replaceOpWithNewOp<llo::VectorCreateMaskOp>( |
| op, op.getType(), |
| /*sublane_start=*/*sublane_low, |
| /*sublane_end=*/(*sublane_high - 1), |
| /*lane_start=*/*lane_low, /*lane_end=*/(*lane_high - 1)); |
| return success(); |
| }); |
| addPattern<tpu::CreateSubelementMaskOp>( |
| patterns, [this](tpu::CreateSubelementMaskOp op, |
| tpu::CreateSubelementMaskOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isMaskVectorType(op.getType())) { |
| return failure(); |
| } |
| auto vty = cast<VectorType>(op.getType()); |
| if (vty.getRank() != 3 || vty.getDimSize(2) != op.getNumSubelems()) { |
| return failure(); |
| } |
| if (0 > op.getFrom() || op.getFrom() >= op.getTo() || |
| op.getTo() > target_->SublaneCount() * op.getNumSubelems()) { |
| return failure(); |
| } |
| auto create_sublane_mask = [&rewriter, |
| &op](int32_t packed_limit) -> Value { |
| return rewriter.create<llo::VectorCreateSublaneMaskOp>( |
| op.getLoc(), op.getType(), |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(packed_limit))); |
| }; |
| if (target_->DeepseaVersion() == ::tpu::TpuVersion::kPufferfish) { |
| if (op.getNumSubelems() != 1 && op.getNumSubelems() != 2) { |
| return failure(); |
| } |
| int32_t from = op.getFrom(); |
| int32_t to = op.getTo(); |
| if (op.getNumSubelems() == 1) { |
| from *= 2; |
| to *= 2; |
| } |
| Value mask = nullptr; |
| if (from != 0) { |
| uint32_t packed_from = from | (from << 16); |
| mask = rewriter.create<llo::VectorMaskNegateOp>( |
| op.getLoc(), op.getType(), create_sublane_mask(packed_from)); |
| } |
| if (to != target_->SublaneCount() * 2) { |
| uint32_t packed_to = to | (to << 16); |
| auto lt_to = create_sublane_mask(packed_to); |
| if (!mask) { |
| mask = lt_to; |
| } else { |
| mask = rewriter.create<llo::VectorMaskAndOp>( |
| op.getLoc(), op.getType(), mask, lt_to); |
| } |
| } |
| if (!mask) { |
| rewriter.replaceOpWithNewOp<llo::ConstantOp>( |
| op, op.getType(), |
| rewriter.getIntegerAttr(rewriter.getI1Type(), true)); |
| return success(); |
| } |
| rewriter.replaceOp(op, mask); |
| return success(); |
| } |
| #if GCE_PERMITS_VIPERLITE || GCE_PERMITS_VIPERFISH |
| if (target_->DeepseaVersion() == ::tpu::TpuVersion::kViperfish) { |
| if (op.getNumSubelems() != 1 && op.getNumSubelems() != 2 && |
| op.getNumSubelems() != 4) { |
| return failure(); |
| } |
| uint32_t native_subelems = 4; |
| int32_t from = |
| op.getFrom() * (native_subelems / op.getNumSubelems()); |
|             // We subtract 1 because the upper bound is inclusive in |
|             // Viperfish. |
| int32_t to = |
| (op.getTo() * (native_subelems / op.getNumSubelems())) - 1; |
| CHECK_LE(from, to); |
| CHECK_LE(from, 2 << 4); |
| CHECK_LE(to, 2 << 4); |
| int32_t packed_limit = from | (to << 8); |
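|             // E.g. (values assumed) with num_subelems == 2, from == 1, and |
|             // to == 3, the scaled bounds are 2 and 5 (inclusive), packed as |
|             // 2 | (5 << 8) == 0x0502. |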
| rewriter.replaceOp(op, create_sublane_mask(packed_limit)); |
| return success(); |
| } |
| #endif |
| return failure(); |
| }); |
| addPattern<tpu::AllReduceOp>( |
| patterns, [](tpu::AllReduceOp op, tpu::AllReduceOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (op.getDim() == 0) { |
| switch (op.getKind()) { |
| case tpu::ReductionKind::SUM: |
| rewriter.replaceOpWithNewOp<llo::VectorAddSublaneReduceF32Op>( |
| op, subst.getInput()); |
| break; |
| case tpu::ReductionKind::MAX: |
| rewriter.replaceOpWithNewOp<llo::VectorMaxSublaneReduceF32Op>( |
| op, subst.getInput()); |
| break; |
| default: |
| return failure(); |
| } |
| return success(); |
| } |
| if (op.getDim() == 1) { |
|             switch (op.getKind()) { |
|               case tpu::ReductionKind::SUM: |
|                 rewriter.replaceOpWithNewOp<llo::VectorAddReduceF32Op>( |
|                     op, subst.getInput()); |
|                 break; |
|               case tpu::ReductionKind::MAX: |
|                 rewriter.replaceOpWithNewOp<llo::VectorMaxReduceF32Op>( |
|                     op, subst.getInput()); |
|                 break; |
|               default: |
|                 return failure(); |
|             } |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<tpu::BitcastVregOp>( |
| patterns, [](tpu::BitcastVregOp op, tpu::BitcastVregOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::VectorBitcastOp>(op, op.getType(), |
| subst.getInput()); |
| return success(); |
| }); |
| addPattern<tpu::RollVectorsOp>( |
| patterns, [](tpu::RollVectorsOp op, tpu::RollVectorsOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.eraseOp(op); |
| return success(); |
| }); |
| addPattern<tpu::UnrollVectorsOp>( |
| patterns, [](tpu::UnrollVectorsOp op, tpu::UnrollVectorsOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (auto rop = subst.getInput().getDefiningOp<tpu::RollVectorsOp>()) { |
| if (rop.getNumOperands() != op.getNumResults()) { |
| return failure(); |
| } |
| for (auto [v1, v2] : |
| llvm::zip(rop.getOperandTypes(), op.getResultTypes())) { |
| if (v1 != v2) { |
| return failure(); |
| } |
| } |
| rewriter.replaceOp(op, rop->getOperands()); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<tpu::LoadOp>( |
| patterns, |
| [this](tpu::LoadOp op, tpu::LoadOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| auto base_ty = getMemRefType(op.getBase()); |
| auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter); |
| if (!hasMemorySpace(base_ty, tpu::MemorySpace::vmem)) { |
| return failure(); |
| } |
| auto sublane_mask = op.getSublaneMask(); |
| if (sublane_mask.size() != sublaneCount()) { |
| return failure(); |
| } |
| auto [sublane_mask_i32, num_read_sublanes, all_sublanes_read] = |
| encodeSublaneMask(sublane_mask); |
| std::optional<Value> offset = indicesToVmemOffset( |
| subst.getIndices(), num_read_sublanes, base_ty, rewriter, |
| /*consistent_directions=*/false); |
| if (!offset) { |
| return op.emitOpError("Failed to convert indices to linear offset"); |
| } |
| Value sublane_mask_v = nullptr; |
| if (!all_sublanes_read) { |
| sublane_mask_v = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getIntegerType(32, false), |
| rewriter.getUI32IntegerAttr(sublane_mask_i32)); |
| } |
| |
| uint32_t sublane_stride = op.getSublaneStride().value_or(1); |
| if (sublane_stride <= 1) { |
| auto [base, displacement] = |
| vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset); |
| rewriter.replaceOpWithNewOp<llo::VectorLoadOp>( |
| op, op.getType(), base, displacement, sublane_mask_v, |
| sublane_stride); |
| } else { |
| Value addr = base_addr; |
| if (*offset) { |
| addr = rewriter.create<llo::ScalarAddressVmemOp>(op.getLoc(), |
| addr, *offset); |
| } |
| rewriter.replaceOpWithNewOp<llo::VldWithArbitrarySlaneStrideOp>( |
| op, op.getType(), addr, sublane_stride, sublaneCount(), |
| sublane_mask_v); |
| } |
| return success(); |
| }); |
| addPattern<tpu::StoreOp>( |
| patterns, [this](tpu::StoreOp op, tpu::StoreOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(subst.getValueToStore().getType())) { |
| return failure(); |
| } |
| auto source_ty = getMemRefType(op.getBase()); |
| auto [base_addr, _] = unpackMemRef(subst.getBase(), rewriter); |
| if (!hasMemorySpace(source_ty, tpu::MemorySpace::vmem)) { |
| return failure(); |
| } |
| auto sublane_mask = op.getSublaneMask(); |
| if (sublane_mask.size() != sublaneCount()) { |
| return failure(); |
| } |
| auto [sublane_mask_i32, num_read_sublanes, all_sublanes_read] = |
| encodeSublaneMask(sublane_mask); |
| std::optional<Value> offset = indicesToVmemOffset( |
| subst.getIndices(), num_read_sublanes, source_ty, rewriter, |
| /*consistent_directions=*/false); |
| if (!offset) { |
| return failure(); |
| } |
| Value sublane_mask_v = nullptr; |
| if (!all_sublanes_read) { |
| sublane_mask_v = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getIntegerType(32, false), |
| rewriter.getUI32IntegerAttr(sublane_mask_i32)); |
| } |
| auto sublane_stride = op.getSublaneStride().value_or(1); |
| if (sublane_stride <= 1) { |
| auto [base, displacement] = |
| vmemAddrOffsetToAddrDisplacement(rewriter, base_addr, *offset); |
| if (auto mask = op.getMask()) { |
| rewriter.replaceOpWithNewOp<llo::VectorStoreMaskedOp>( |
| op, /*address=*/base, |
| /*displacement=*/displacement, |
| /*mask=*/mask, /*to_store=*/subst.getValueToStore(), |
| /*sublane_stride=*/sublane_stride, |
| /*sublanes_per_stride=*/1, |
| /*sublane_mask=*/sublane_mask_v); |
| } else { |
| rewriter.replaceOpWithNewOp<llo::VectorStoreOp>( |
| op, /*address=*/base, |
| /*displacement=*/displacement, |
| /*to_store=*/subst.getValueToStore(), |
| /*sublane_mask=*/sublane_mask_v, |
| /*sublane_stride=*/sublane_stride); |
| } |
| } else { |
| Value addr = base_addr; |
| if (*offset) { |
| addr = rewriter.create<llo::ScalarAddressVmemOp>(op.getLoc(), |
| addr, *offset); |
| } |
| if (auto mask = op.getMask()) { |
| rewriter |
| .replaceOpWithNewOp<llo::VstMaskedWithArbitrarySlaneStrideOp>( |
| op, addr, mask, subst.getValueToStore(), sublane_stride, |
| sublaneCount(), sublane_mask_v); |
| } else { |
| rewriter.replaceOpWithNewOp<llo::VstWithArbitrarySlaneStrideOp>( |
| op, addr, subst.getValueToStore(), sublane_stride, |
| sublaneCount(), sublane_mask_v); |
| } |
| } |
| return success(); |
| }); |
| addPattern<tpu::MatmulOp>( |
| patterns, |
| [this](tpu::MatmulOp op, tpu::MatmulOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| bool transposed = op.getTransposeRhs(); |
| |
| bool high_precision = false; |
| if (auto precision = op.getPrecision()) { |
| switch (*precision) { |
| case tpu::ContractPrecision::kBF16: |
| high_precision = false; |
| break; |
| case tpu::ContractPrecision::kFP32: |
| high_precision = true; |
| break; |
| } |
| } |
| |
| CHECK_EQ(target_->MxuContractingSize(), |
| target_->MxuNoncontractingSize()); |
| int64_t mxu_size = target_->MxuContractingSize(); |
| CHECK_EQ(mxu_size % laneCount(), 0); |
| |
| auto lhs_type = cast<VectorType>(subst.getLhs().getType()); |
| int64_t lhs_packing = 32 / lhs_type.getElementTypeBitWidth(); |
| auto lhs_shape = lhs_type.getShape(); |
| if (lhs_shape[0] <= 0 || |
| lhs_shape[0] % (sublaneCount() * lhs_packing) != 0 || |
| lhs_shape[1] != mxu_size || |
| (high_precision && !lhs_type.getElementType().isF32())) { |
| op.emitOpError() << "Bad lhs type in tpu.matmul"; |
| return failure(); |
| } |
| mlir::ValueRange lhs_vectors; |
| if (auto rvop = subst.getLhs().getDefiningOp<tpu::RollVectorsOp>()) { |
| lhs_vectors = rvop->getOperands(); |
| } else { |
| return failure(); |
| } |
| |
| auto rhs_type = cast<VectorType>(subst.getRhs().getType()); |
| int64_t rhs_packing = 32 / rhs_type.getElementTypeBitWidth(); |
| if (rhs_type.getShape() != ArrayRef<int64_t>{mxu_size, mxu_size} || |
| (high_precision && !rhs_type.getElementType().isF32())) { |
| op.emitOpError() << "Bad rhs type in tpu.matmul"; |
| return failure(); |
| } |
| mlir::ValueRange rhs_vectors; |
| if (auto rvop = subst.getRhs().getDefiningOp<tpu::RollVectorsOp>()) { |
| rhs_vectors = rvop->getOperands(); |
| CHECK_EQ(rhs_vectors.size(), |
| mxu_size * mxu_size / |
| (sublaneCount() * rhs_packing * laneCount())); |
| } else { |
| return failure(); |
| } |
| |
| auto acc_type = cast<VectorType>(subst.getAcc().getType()); |
| if (acc_type.getShape() != lhs_type.getShape() || |
| acc_type.getElementTypeBitWidth() != 32) { |
| op.emitOpError() << "Bad acc type in tpu.matmul"; |
| return failure(); |
| } |
| std::vector<Value> acc_vectors; |
| if (auto rvop = subst.getAcc().getDefiningOp<tpu::RollVectorsOp>()) { |
| acc_vectors.insert(acc_vectors.end(), rvop.operand_begin(), |
| rvop.operand_end()); |
| } else { |
| return failure(); |
| } |
| int64_t acc_lhs_compression = acc_type.getElementTypeBitWidth() / |
| lhs_type.getElementTypeBitWidth(); |
| CHECK_EQ(acc_vectors.size(), |
| lhs_vectors.size() * acc_lhs_compression); |
| |
| // Make sure the type of the accumulator matches the operand types. |
| bool integral = acc_type.getElementType().isSignlessInteger(); |
| if (integral) { |
| if (!lhs_type.getElementType().isSignlessInteger(8) || |
| !rhs_type.getElementType().isSignlessInteger(8)) { |
| return failure(); |
| } |
| } else { |
| if (!isa<BFloat16Type, Float32Type>(lhs_type.getElementType()) || |
| !isa<BFloat16Type, Float32Type>(rhs_type.getElementType())) { |
| return failure(); |
| } |
| } |
| |
| // The code below assumes MXU size is either 128 or 256. |
| if (mxu_size != 128 && mxu_size != 256) { |
| op.emitOpError() |
| << "Not implemented: unsupported MXU size " << mxu_size; |
| return failure(); |
| } |
| |
| auto loc = op.getLoc(); |
| auto push_latch = [&rhs_vectors = std::as_const(rhs_vectors), |
| &rewriter, mxu_size, loc, |
| this](llo::GainLatchMode mode) { |
| // Here tile refers to 128x128 tile. |
| auto tile_cols = mxu_size / laneCount(); |
| auto tile_rows = tile_cols; |
| auto vregs_per_tile = rhs_vectors.size() / (tile_cols * tile_rows); |
| bool reverse_push = target_->ShouldReverseTileForLatching(); |
| |
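|             // E.g. (values assumed) with mxu_size == 256 and 128 lanes, |
|             // tile_cols == tile_rows == 2; for every vreg row the columns |
|             // are pushed last-to-first: matprep for tc > 0 and a latch at |
|             // tc == 0. |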
| for (int tr = 0; tr < tile_rows; ++tr) { |
| // Vreg index in the 128x128 tile. |
| int vi = reverse_push ? vregs_per_tile - 1 : 0; |
| while (vi >= 0 && vi < vregs_per_tile) { |
| for (int tc = tile_cols - 1; tc >= 0; --tc) { |
| int idx = |
| tr * tile_cols * vregs_per_tile + vi * tile_cols + tc; |
| Value rhs_vector = rhs_vectors[idx]; |
| // JF and DF want to be fed the latch with reversed sublanes. |
| if (target_->ShouldReverseForLatching()) { |
| rhs_vector = rewriter.create<llo::VectorSublaneReverseOp>( |
| loc, rhs_vector.getType(), rhs_vector); |
| } |
| if (tile_cols == 1) { |
| rewriter.create<llo::VectorLatchOp>(loc, rhs_vector, mode); |
| } else if (tc == 0) { |
| // On GFC, we switched from vlatch 1/2 to just vlatch1. |
| int variant = (reverse_push && tr > 0) ? 2 : 1; |
| rewriter.create<llo::VectorLatchIOp>(loc, rhs_vector, |
| variant, mode); |
| } else { |
| rewriter.create<llo::VectorMatprepSubrOp>(loc, rhs_vector, |
| mode); |
| } |
| } |
| vi = reverse_push ? vi - 1 : vi + 1; |
| } |
| } |
| }; |
| // Watch out! Calling this function mutates acc_vectors |
| // by accumulating the matmul result into them. |
| auto matmul = [&lhs_vectors = std::as_const(lhs_vectors), |
| &acc_vectors, &rewriter, mxu_size, integral, lhs_type, |
| loc, acc_lhs_compression, this](llo::MatmulMode mode) { |
| llo::MatmulDataFormat data_format = llo::MatmulDataFormat::kF32; |
| if (lhs_type.getElementType().isBF16()) { |
| data_format = llo::MatmulDataFormat::kBf16; |
| CHECK(mode == llo::MatmulMode::kRound); |
| } else if (lhs_type.getElementType().isSignlessInteger(8)) { |
| data_format = llo::MatmulDataFormat::kS8; |
| CHECK(mode == llo::MatmulMode::kRound); |
| } else if (!lhs_type.getElementType().isF32()) { |
| LOG(FATAL) << "Unexpected LHS type"; |
| } |
| |
| auto col_cnt = mxu_size / laneCount(); |
| auto row_cnt = lhs_vectors.size() / col_cnt; |
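|             // E.g. (values assumed) with mxu_size == 256 and 128 lanes, |
|             // col_cnt == 2: each lhs row issues one matprep (c == 1) |
|             // followed by a matmul at c == 0. |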
| for (int r = 0; r < row_cnt; ++r) { |
| for (int c = col_cnt - 1; c >= 0; --c) { |
| Value lhs_vector = lhs_vectors[r * col_cnt + c]; |
| if (col_cnt == 1) { |
| rewriter.create<llo::VectorMatmulOp>( |
| loc, lhs_vector, mode, |
| /*mxu_id=*/0, |
| /*dwg=*/false, |
| /*data_format=*/data_format); |
| } else if (c == 0) { |
| rewriter.create<llo::VectorMatmulMubrOp>( |
| loc, lhs_vector, mode, |
| /*mxu_id=*/0, |
| /*dwg=*/false, |
| /*data_format=*/data_format); |
| } else { |
| rewriter.create<llo::VectorMatprepMubrOp>(loc, lhs_vector, |
| mode); |
| } |
| } |
| // Now the matmul is done, we can accumulate the result to acc. |
| for (int ur = 0; ur < acc_lhs_compression; ++ur) { |
| for (int c = 0; c < col_cnt; ++c) { |
| int idx = |
| r * acc_lhs_compression * col_cnt + ur * col_cnt + c; |
| Value result = rewriter.create<llo::VectorMatresOp>( |
| loc, acc_vectors[idx].getType(), data_format); |
| if (integral) { |
| acc_vectors[idx] = rewriter.create<llo::VectorAddS32Op>( |
| loc, acc_vectors[idx], result); |
| } else { |
| acc_vectors[idx] = rewriter.create<llo::VectorAddF32Op>( |
| loc, acc_vectors[idx], result); |
| } |
| } |
| } |
| } |
| }; |
| |
| using llo::MatmulMode; |
| using llo::GainLatchMode; |
| using llo::MatmulDataFormat; |
| if (high_precision) { |
| // This path is only taken for fp32 inputs. All targets support |
| // the matmul and latch modes used here, so no checks are necessary. |
| using MatmulInstance = std::pair<MatmulMode, GainLatchMode>; |
| static const std::array<MatmulInstance, 6> matmul_sequence = { |
| MatmulInstance{MatmulMode::kSoftLowEight, |
| GainLatchMode::kNoXposeHiF32}, |
| MatmulInstance{MatmulMode::kHigh, |
| GainLatchMode::kNoXposeSoftLowEightF32}, |
| MatmulInstance{MatmulMode::kLow, GainLatchMode::kNoXposeLowF32}, |
| MatmulInstance{MatmulMode::kSoftMiddleEight, |
| GainLatchMode::kNoXposeHiF32}, |
| MatmulInstance{MatmulMode::kHigh, |
| GainLatchMode::kNoXposeSoftMiddleEightF32}, |
| MatmulInstance{MatmulMode::kHigh, GainLatchMode::kNoXposeHiF32}, |
| }; |
| std::optional<GainLatchMode> last_latch_mode = std::nullopt; |
| for (auto [matmul_mode, latch_mode_no_transp] : matmul_sequence) { |
| if (!last_latch_mode.has_value() || |
| *last_latch_mode != latch_mode_no_transp) { |
| if (last_latch_mode.has_value()) { |
| rewriter.create<llo::VectorDoneWithGainsOp>(op.getLoc()); |
| } |
| GainLatchMode latch_mode = latch_mode_no_transp; |
| if (transposed) { |
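|                   // Adding 1 assumes each transposed GainLatchMode value |
|                   // immediately follows its non-transposed counterpart in |
|                   // the enum. |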
| latch_mode = static_cast<GainLatchMode>( |
| static_cast<int32_t>(latch_mode) + 1); |
| } |
| push_latch(latch_mode); |
| last_latch_mode = latch_mode_no_transp; |
| } |
| matmul(matmul_mode); |
| } |
| } else { |
| GainLatchMode latch_mode; |
| MatmulDataFormat data_format; |
| if (rhs_type.getElementType().isF32()) { |
| latch_mode = transposed ? GainLatchMode::kXposeF32 |
| : GainLatchMode::kNoXposeF32; |
| data_format = MatmulDataFormat::kF32; |
| } else if (rhs_type.getElementType().isBF16()) { |
| latch_mode = transposed ? GainLatchMode::kXposePackedBf16 |
| : GainLatchMode::kNoXposePackedBf16; |
| data_format = MatmulDataFormat::kBf16; |
| } else if (rhs_type.getElementType().isSignlessInteger(8)) { |
| latch_mode = transposed ? GainLatchMode::kXposeS8 |
| : GainLatchMode::kNoXposeS8; |
| data_format = MatmulDataFormat::kS8; |
| } else { |
| return failure(); |
| } |
| if (!target_->SupportsMatmulDataFormat( |
| static_cast<xla::jellyfish::MatmulDataFormat>( |
| data_format)) || |
| !target_->SupportsGainLatchMode( |
| static_cast<xla::jellyfish::GainLatchMode>(latch_mode))) { |
| op->emitOpError( |
| "Unsupported input data type in matrix multiplication."); |
| return failure(); |
| } |
| push_latch(latch_mode); |
| matmul(MatmulMode::kRound); |
| } |
| |
| rewriter.create<llo::VectorDoneWithGainsOp>(op.getLoc()); |
| rewriter.replaceOpWithNewOp<tpu::RollVectorsOp>(op, lhs_type, |
| acc_vectors); |
| return success(); |
| }); |
| addPattern<tpu::RotateOp>( |
| patterns, [this](tpu::RotateOp op, tpu::RotateOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| int32_t amount = op.getAmount(); |
| if (op.getDimension() == 0) { |
| if (amount < 0 || amount >= sublaneCount()) { |
| return failure(); |
| } |
| if (op.getStride().has_value() || |
| op.getStrideDimension().has_value()) { |
| op.emitOpError("Not implemented: rotate sublanes with stride."); |
| return failure(); |
| } |
| if (amount % sublaneCount() == 0) { |
| rewriter.replaceOp(op, subst.getValue()); |
| return success(); |
| } |
| // NOTE: LLO rotates sublanes down. |
| rewriter.replaceOpWithNewOp<llo::VectorSublaneRotateOp>( |
| op, subst.getValue(), sublaneCount() - amount); |
| return success(); |
| } |
| if (op.getDimension() == 1) { |
| if (amount < 0 || amount >= laneCount()) { |
| return failure(); |
| } |
| if (op.getStride().has_value() && |
| op.getStrideDimension().value_or(1) != 0) { |
|               op.emitOpError("Expected the stride dimension to be 0."); |
| return failure(); |
| } |
| if (amount % laneCount() == 0 && op.getStride().value_or(0) == 0) { |
| rewriter.replaceOp(op, subst.getValue()); |
| return success(); |
| } |
| rewriter.replaceOpWithNewOp<llo::VectorRotateOp>( |
| op, subst.getValue(), amount, op.getStrideAttr()); |
| return success(); |
| } |
| return failure(); |
| }); |
| addPattern<tpu::GatherOp>( |
| patterns, [this](tpu::GatherOp op, tpu::GatherOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto indices = op.getIndices(); |
| if (!isRepresentableVectorType(op.getType()) || |
| op.getDimension() != 0 || indices.size() != sublaneCount()) { |
| return failure(); |
| } |
| bool constant = true; |
| for (int32_t idx : indices) { |
| if (idx < 0 || idx >= sublaneCount()) { |
| return failure(); |
| } |
| constant &= idx == indices[0]; |
| } |
| if (constant) { |
| rewriter.replaceOpWithNewOp<llo::VectorSublaneReplicateOp>( |
| op, op.getType(), op.getSource(), |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(indices[0]))); |
| } else { |
| rewriter.replaceOpWithNewOp<llo::VectorSublaneShuffleOp>( |
| op, op.getType(), op.getSource(), op.getIndicesAttr()); |
| } |
| return success(); |
| }); |
| addPattern<tpu::DynamicGatherOp>( |
| patterns, |
| [this](tpu::DynamicGatherOp op, tpu::DynamicGatherOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| if (op.getDimension() != 1) { |
| return failure(); |
| } |
| auto set_permute = rewriter.create<llo::VectorSetPermutePatternOp>( |
| op.getLoc(), op.getType(), subst.getIndices(), |
| llo::SetPermuteMode::kOneSublane, /*xlu_id=*/0, |
| /*source_bus=*/nullptr); |
| auto request = rewriter.create<llo::VectorPermuteOp>( |
| op.getLoc(), op.getType(), subst.getSource(), set_permute, |
| /*xlu_id=*/0, |
| /*source_bus=*/nullptr); |
| rewriter.replaceOpWithNewOp<llo::VectorPermuteResultOp>( |
| op, op.getType(), request, 0); |
| return success(); |
| }); |
| addPattern<tpu::IotaOp>( |
| patterns, [this](tpu::IotaOp op, tpu::IotaOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isRepresentableVectorType(op.getType()) || |
| op.getType().getElementType() != rewriter.getI32Type()) { |
| return failure(); |
| } |
| if (op.getDimension() == 0) { |
| rewriter.replaceOpWithNewOp<llo::VectorSublaneId>(op, op.getType()); |
| } else if (op.getDimension() == 1) { |
| rewriter.replaceOpWithNewOp<llo::VectorLaneId>(op, op.getType()); |
| } else if (!op.getDimension().has_value()) { |
| rewriter.replaceOpWithNewOp<llo::VectorLaneSeqOp>(op, op.getType()); |
| } else { |
| return failure(); |
| } |
| return success(); |
| }); |
| // EraseLayoutOp exists only to appease some silly canonicalization rules |
| // and has no operational meaning. |
| addPattern<tpu::EraseLayoutOp>( |
| patterns, [](tpu::EraseLayoutOp op, tpu::EraseLayoutOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOp(op, subst.getOperand()); |
| return success(); |
| }); |
| // MemRefSqueezeOp is a type cast and has no operational meaning. |
| addPattern<tpu::MemRefSqueezeOp>( |
| patterns, [](tpu::MemRefSqueezeOp op, tpu::MemRefSqueezeOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOp(op, subst.getInput()); |
| return success(); |
| }); |
| // ReinterpretCastOp is a type cast and has no operational meaning. |
| addPattern<tpu::ReinterpretCastOp>( |
| patterns, |
| [](tpu::ReinterpretCastOp op, tpu::ReinterpretCastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOp(op, subst.getInput()); |
| return success(); |
| }); |
| addPattern<tpu::PackSubelementsOp>( |
| patterns, |
| [this](tpu::PackSubelementsOp op, tpu::PackSubelementsOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto sources = subst.getSources(); |
| auto packed_ty = op.getType().getElementType(); |
| int packing = target_->VectorScalarBitWidth() / |
| packed_ty.getIntOrFloatBitWidth(); |
| if (sources.size() != packing) { |
| return failure(); |
| } |
| auto source_ty = cast<VectorType>(sources.front().getType()); |
| if (!isRepresentableVectorType(source_ty) || |
| !isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| if (packed_ty.isBF16() && source_ty.getElementType().isF32()) { |
| if (target_->SupportsVectorPackOps( |
| xla::jellyfish::VpackFormat::kCompressedBf16)) { |
| rewriter.replaceOpWithNewOp<llo::VectorPackOp>( |
| op, op.getType(), llo::VpackFormat::kCompressedBf16, |
| sources[1], sources[0]); |
| return success(); |
| } |
| auto exe_region = |
| rewriter.create<llo::RegionOp>(op.getLoc(), op.getType()); |
| Block &block = exe_region.getRegion().emplaceBlock(); |
| { |
| ConversionPatternRewriter::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(&block); |
| // TODO(b/280782839): Fold this into successive stores. |
| auto addr = rewriter.create<llo::AllocaVmemOp>( |
| op.getLoc(), |
| target_->ChunkSizeBytes() / target_->VmemWordSizeBytes()); |
| for (int i = 0; i < sources.size(); ++i) { |
| rewriter.create<llo::VectorStoreAndPackX16Op>( |
| op.getLoc(), addr, i, sources[i], rewriter.getBF16Type()); |
| } |
| auto packed = rewriter.create<llo::VectorLoadOp>( |
| op.getLoc(), op.getType(), addr, nullptr, nullptr); |
| rewriter.create<llo::YieldOp>(op.getLoc(), packed.getResult()); |
| } |
| rewriter.replaceOp(op, exe_region.getResult(0)); |
| return success(); |
| } |
| if (source_ty.getElementType().isSignlessInteger(32) && |
| packed_ty.isSignlessInteger()) { |
| unsigned bitwidth = 32; |
| std::vector<Value> sources(op.getSources().begin(), |
| op.getSources().end()); |
| std::vector<Value> next_sources; |
| next_sources.reserve(sources.size() / 2); |
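|             // The loop below halves the element width each round; e.g. for |
|             // an assumed i32 -> i8 pack, 4 i32 vregs become 2 i16 vregs and |
|             // then 1 i8 vreg. |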
| while (bitwidth > packed_ty.getIntOrFloatBitWidth()) { |
| CHECK_EQ(sources.size(), |
| bitwidth / packed_ty.getIntOrFloatBitWidth()); |
| llo::VpackFormat format = GetCompressedByteFormat(bitwidth / 2); |
| if (format == llo::VpackFormat::kInvalid) { |
| op.emitOpError() << "unsupported source element type " |
| << source_ty.getElementType(); |
| return failure(); |
| } |
| if (!target_->SupportsVectorPackOps( |
| static_cast<xla::jellyfish::VpackFormat>(format))) { |
| op.emitOpError() |
| << "Target TPU does not support required instructions"; |
| return failure(); |
| } |
|                 auto half_packed_ty = |
|                     VectorType::get(nativeVectorShape(bitwidth / 2), |
|                                     rewriter.getIntegerType(bitwidth / 2)); |
|                 for (int i = 0; i < sources.size(); i += 2) { |
|                   next_sources.push_back(rewriter.create<llo::VectorPackOp>( |
|                       op.getLoc(), half_packed_ty, format, sources[i + 1], |
|                       sources[i])); |
| } |
| bitwidth /= 2; |
| std::swap(sources, next_sources); |
| next_sources.clear(); |
| } |
| CHECK_EQ(bitwidth, packed_ty.getIntOrFloatBitWidth()); |
| CHECK_EQ(sources.size(), 1); |
| rewriter.replaceOp(op, sources.front()); |
| return success(); |
| } |
| op.emitOpError() << "unsupported source element type " |
| << source_ty.getElementType(); |
| return failure(); |
| }); |
| addPattern<tpu::UnpackSubelementsOp>( |
| patterns, [this](tpu::UnpackSubelementsOp op, |
| tpu::UnpackSubelementsOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto unpacked_ty = op.getType().getElementType(); |
| auto source_ty = op.getSource().getType(); |
| if (!isRepresentableVectorType(source_ty) || |
| !isRepresentableVectorType(op.getType())) { |
| return failure(); |
| } |
| llo::VpackFormat format = llo::VpackFormat::kInvalid; |
| if (unpacked_ty.isSignlessInteger(32) && |
| source_ty.getElementType().isSignlessInteger()) { |
| format = |
| GetCompressedByteFormat(source_ty.getElementTypeBitWidth()); |
| } else if (unpacked_ty.isF32() && |
| source_ty.getElementType().isBF16()) { |
| format = llo::VpackFormat::kCompressedBf16; |
| } |
| if (format == llo::VpackFormat::kInvalid) { |
| op.emitOpError() << "unsupported source element type " |
| << source_ty.getElementType(); |
| return failure(); |
| } |
| if (target_->SupportsVectorUnpackOps( |
| static_cast<xla::jellyfish::VpackFormat>(format))) { |
| rewriter.replaceOpWithNewOp<llo::VectorUnpackOp>( |
| op, op.getType(), op.getIndex(), format, subst.getSource()); |
| } else { |
| auto exe_region = |
| rewriter.create<llo::RegionOp>(op.getLoc(), op.getType()); |
| Block &block = exe_region.getRegion().emplaceBlock(); |
| { |
| ConversionPatternRewriter::InsertionGuard guard(rewriter); |
| rewriter.setInsertionPointToStart(&block); |
| // TODO(b/280782839): We only need part of this space. |
| // TODO(b/280782839): Fold this into preceding loads. |
| auto addr = rewriter.create<llo::AllocaVmemOp>( |
| op.getLoc(), |
| target_->ChunkSizeBytes() / target_->VmemWordSizeBytes()); |
| rewriter.create<llo::VectorStoreOp>(op.getLoc(), addr, nullptr, |
| subst.getSource(), nullptr); |
| Type signed_source_element_ty = source_ty.getElementType(); |
| if (signed_source_element_ty.isSignlessInteger()) { |
| signed_source_element_ty = rewriter.getIntegerType( |
| source_ty.getElementTypeBitWidth(), true); |
| } |
| Value unpacked = rewriter.create<llo::VectorLoadAndUnpackOp>( |
| op.getLoc(), op.getType(), addr, op.getIndex(), |
| signed_source_element_ty); |
| rewriter.create<llo::YieldOp>(op.getLoc(), unpacked); |
| } |
| rewriter.replaceOp(op, exe_region.getResult(0)); |
| } |
| return success(); |
| }); |
| addPattern<tpu::AllocaSemaphoreOp>( |
| patterns, |
| [](tpu::AllocaSemaphoreOp op, tpu::AllocaSemaphoreOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| MemRefType ref_ty = op.getResult().getType(); |
| auto layout = |
| dyn_cast<mlir::tpu::TiledLayoutAttr>(ref_ty.getLayout()); |
| if (!layout) { |
| return op.emitOpError() << "need a tiled layout"; |
| } |
| if (!layout.getTiles().empty()) { |
| return op.emitOpError() << "tiling unsupported"; |
| } |
| auto tile_strides = layout.getTileStrides(); |
| if (tile_strides.size() != ref_ty.getRank()) { |
| return failure(); |
| } |
| int stride = 1; |
| for (int i = ref_ty.getRank() - 1; i >= 0; --i) { |
| if (tile_strides[i] != stride) { |
| return op.emitOpError() |
| << "non-contiguous allocations not supported"; |
| } |
| if (ref_ty.isDynamicDim(i)) { |
| return op->emitOpError() |
| << "dynamic shapes not supported in allocations"; |
| } |
| stride *= ref_ty.getDimSize(i); |
| } |
| rewriter.replaceOpWithNewOp<llo::AllocaSyncFlagOp>( |
| op, rewriter.getI32Type(), stride); |
| return success(); |
| }); |
| addPattern<tpu::DeviceIdOp>( |
| patterns, [](tpu::DeviceIdOp op, tpu::DeviceIdOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::LogicalDeviceIdOp>( |
| op, rewriter.getI32Type()); |
| return success(); |
| }); |
| addPattern<tpu::SemaphoreReadOp>( |
| patterns, [](tpu::SemaphoreReadOp op, tpu::SemaphoreReadOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::VSyncRead>(op, subst.getSemaphore()); |
| return success(); |
| }); |
| addPattern<tpu::SemaphoreWaitOp>( |
| patterns, [](tpu::SemaphoreWaitOp op, tpu::SemaphoreWaitOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| auto c0 = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(0)); |
| auto neg_amount = rewriter.create<llo::ScalarSubS32Op>( |
| op.getLoc(), rewriter.getI32Type(), c0, subst.getAmount()); |
| // This could race if multiple cores were to wait on the same |
| // semaphore, but I don't think this can happen. |
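|           // Waiting for `amount` first subtracts it from the flag and then |
|           // blocks until the flag becomes non-negative again. |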
| rewriter.create<llo::VSyncAddOp>(op.getLoc(), subst.getSemaphore(), |
| neg_amount); |
| rewriter.replaceOpWithNewOp<llo::VWaitGeOp>(op, subst.getSemaphore(), |
| c0); |
| return success(); |
| }); |
| addPattern<tpu::SemaphoreSignalOp>( |
| patterns, |
| [this](tpu::SemaphoreSignalOp op, tpu::SemaphoreSignalOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!op.getDeviceId()) { |
| rewriter.replaceOpWithNewOp<llo::VSyncAddOp>( |
| op, subst.getSemaphore(), subst.getAmount()); |
| } else { |
| // If there are multiple tensor cores, then we might need to decode |
| // device id and core index from subst's device_id. |
| if (target_->TensorCoresPerChip() > 1 && |
| !unsafe_allow_multicore_remote_dma_) { |
| return failure(); |
| } |
| auto core_index = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(0)); |
| rewriter.replaceOpWithNewOp<llo::VSyncAddRemoteOp>( |
| op, subst.getSemaphore(), subst.getDeviceId(), core_index, |
| subst.getAmount()); |
| } |
| return success(); |
| }); |
| addPattern<tpu::EnqueueDMAOp>( |
| patterns, |
| [this](tpu::EnqueueDMAOp op, tpu::EnqueueDMAOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| // If there are multiple tensor cores, then we might need to decode |
| // device id and core index from subst's device_id. Local DMAs are ok. |
|           if (op.getDeviceId() && target_->TensorCoresPerChip() > 1 && |
|               !unsafe_allow_multicore_remote_dma_) { |
|             return op.emitOpError("Multicore remote DMAs are not supported."); |
|           } |
| auto src_ty = getMemRefType(op.getSource()); |
| auto tgt_ty = getMemRefType(op.getTarget()); |
| auto [src_addr, src_dynamic_sizes] = |
| unpackMemRef(subst.getSource(), rewriter); |
| auto [tgt_addr, tgt_dynamic_sizes] = |
| unpackMemRef(subst.getTarget(), rewriter); |
| if (src_ty.getElementType() != tgt_ty.getElementType()) { |
| op.emitOpError("DMA source and target element type mismatch."); |
| return failure(); |
| } |
| int64_t bytewidth = src_ty.getElementTypeBitWidth() / 8; |
| if (src_ty.getShape() != tgt_ty.getShape()) { |
| return op.emitOpError("DMA source and target shape mismatch."); |
| } |
| // Check that dynamic sizes match too. |
| CHECK_EQ(src_dynamic_sizes.size(), tgt_dynamic_sizes.size()); |
| for (auto [src_size, dst_size] : |
| llvm::zip(src_dynamic_sizes, tgt_dynamic_sizes)) { |
| rewriter.create<llo::ErrorIfOp>( |
| op.getLoc(), |
| rewriter.create<llo::ScalarCmpNeS32Op>(op.getLoc(), src_size, |
| dst_size), |
| rewriter.getStringAttr("Dynamic shape mismatch")); |
| } |
| |
| ArrayRef<int64_t> shape = src_ty.getShape(); |
| // Note that memory space can differ. Only layout left to be checked. |
| auto src_layout = cast<tpu::TiledLayoutAttr>(src_ty.getLayout()); |
| auto tgt_layout = cast<tpu::TiledLayoutAttr>(tgt_ty.getLayout()); |
| // Strides can differ, but tiling should be the same. |
| if (src_layout.getTiles() != tgt_layout.getTiles()) { |
| op.emitOpError("DMA source and target tiling mismatch."); |
| return failure(); |
| } |
| auto tiling = src_layout.getTiles().front().dimensions(); |
| |
| SmallVector<Value> tiled_dynamic_shape; |
| if (failed(getTiledDynamicShape(&tiled_dynamic_shape, shape, |
| src_dynamic_sizes, tiling, rewriter, |
| op))) { |
| return failure(); |
| } |
| |
| int32_t tile_bytes = bytewidth; |
| for (int64_t s : tiling) { |
| tile_bytes *= s; |
| } |
| if (tile_bytes % 512 != 0) { |
| return op.emitOpError("Tile size must be divisible by 512 bytes."); |
| } |
| Value num_elements = |
| getDynamicNumElements(shape, src_dynamic_sizes, rewriter); |
| Value size_512b = rewriter.create<llo::ScalarDivS32Op>( |
| op.getLoc(), rewriter.getI32Type(), num_elements, |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(512 / bytewidth))); |
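|           // E.g. for f32 (bytewidth == 4), 512 / bytewidth == 128 elements |
|           // fit in one 512-byte DMA unit, so size_512b is num_elements / 128. |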
| |
| SmallVector<Value> steps_per_stride = tiled_dynamic_shape; |
| steps_per_stride.back() = rewriter.create<llo::ScalarMulS32Op>( |
| op.getLoc(), rewriter.getI32Type(), steps_per_stride.back(), |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(tile_bytes))); |
| auto get_strides = [&](TiledLayoutAttr layout) |
| -> mlir::FailureOr<SmallVector<int32_t>> { |
| if (layout.getTileStrides().back() != 1) { |
|               op.emitOpError("Expected the minormost tile stride to be 1."); |
| return failure(); |
| } |
| SmallVector<int32_t> strides_in_bytes; |
| strides_in_bytes.reserve(layout.getTileStrides().size()); |
| for (int64_t stride : layout.getTileStrides()) { |
| strides_in_bytes.push_back(stride * tile_bytes); |
| } |
| return strides_in_bytes; |
| }; |
| auto src_strides = get_strides(src_layout); |
| auto dst_strides = get_strides(tgt_layout); |
| if (failed(src_strides) || failed(dst_strides)) { |
| return failure(); |
| } |
| CHECK_EQ(steps_per_stride.size(), src_strides->size()); |
| CHECK_EQ(steps_per_stride.size(), dst_strides->size()); |
| auto src_tile_strides = src_layout.getTileStrides(); |
| auto tgt_tile_strides = tgt_layout.getTileStrides(); |
| // We try to flatten pairs of dimensions that don't need to use |
| // multi-level striding. |
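|           // E.g. if dim i has stride == size(i + 1) * stride(i + 1) in both |
|           // layouts, dims i and i + 1 form one contiguous run and collapse |
|           // into a single stride level. |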
| for (int i = steps_per_stride.size() - 2; i >= 0; --i) { |
| int64_t tiled_size; |
| if (auto cst = tiled_dynamic_shape[i + 1] |
| .getDefiningOp<llo::ConstantOp>()) { |
| tiled_size = cast<IntegerAttr>(cst.getValue()).getInt(); |
| } else { |
| continue; |
| } |
| if (src_tile_strides[i] == tiled_size * src_tile_strides[i + 1] && |
| tgt_tile_strides[i] == tiled_size * tgt_tile_strides[i + 1]) { |
| steps_per_stride[i + 1] = rewriter.create<llo::ScalarMulS32Op>( |
| op.getLoc(), rewriter.getI32Type(), steps_per_stride[i + 1], |
| steps_per_stride[i]); |
| steps_per_stride.erase(steps_per_stride.begin() + i); |
| src_strides->erase(src_strides->begin() + i); |
| dst_strides->erase(dst_strides->begin() + i); |
| } |
| } |
| // Remove the dimensions XLA expects to be implicit. |
| src_strides->pop_back(); |
| dst_strides->pop_back(); |
| steps_per_stride.erase(steps_per_stride.begin()); |
| if (steps_per_stride.size() > target_->StrideLevelCount()) { |
|             op.emitOpError( |
|                 "Not implemented: DMA requires more stride levels than the " |
|                 "target supports."); |
| return failure(); |
| } |
| |
| Value core_index = nullptr; |
| if (subst.getDeviceId()) { |
| core_index = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(0)); |
| } |
| rewriter.replaceOpWithNewOp<llo::EnqueueDMAOp>( |
| op, src_addr, subst.getSourceSemaphore(), size_512b, tgt_addr, |
| subst.getTargetSemaphore(), |
| rewriter.getDenseI32ArrayAttr(*src_strides), |
| rewriter.getDenseI32ArrayAttr(*dst_strides), steps_per_stride, |
| subst.getDeviceId(), core_index); |
| return success(); |
| }); |
| addPattern<tpu::WaitDMAOp>( |
| patterns, |
| [this](tpu::WaitDMAOp op, tpu::WaitDMAOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| auto dma_ty = getMemRefType(op.getRef()); |
| auto dma_layout = cast<tpu::TiledLayoutAttr>(dma_ty.getLayout()); |
| auto tiling = dma_layout.getTiles().front().dimensions(); |
| ArrayRef<int64_t> shape = dma_ty.getShape(); |
| auto [_, dynamic_sizes] = unpackMemRef(subst.getRef(), rewriter); |
| |
| SmallVector<Value> tiled_dynamic_shape; |
| if (failed(getTiledDynamicShape(&tiled_dynamic_shape, shape, |
| dynamic_sizes, tiling, rewriter, |
| op))) { |
| return failure(); |
| } |
| |
| int32_t tile_bytes = dma_ty.getElementTypeBitWidth() / 8; |
| for (int64_t s : tiling) { |
| tile_bytes *= s; |
| } |
| if (tile_bytes % 512 != 0) { |
| return op.emitOpError("Tile size must be divisible by 512 bytes."); |
| } |
| Value num_elements = |
| getDynamicNumElements(shape, dynamic_sizes, rewriter); |
| Value size_512b = rewriter.create<llo::ScalarDivS32Op>( |
| op.getLoc(), rewriter.getI32Type(), num_elements, |
| rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr( |
| 512 / (dma_ty.getElementTypeBitWidth() / 8)))); |
| |
| rewriter.replaceOpWithNewOp<llo::DMADoneOp>( |
| op, size_512b, subst.getSemaphore(), |
| hasMemorySpace(dma_ty, tpu::MemorySpace::smem)); |
| return success(); |
| }); |
| addPattern<tpu::MemRefSliceOp>( |
| patterns, |
| [this](tpu::MemRefSliceOp op, tpu::MemRefSliceOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) -> LogicalResult { |
| auto indices = subst.getBaseIdx(); |
| auto slice_shape = op.getResult().getType().getShape(); |
| auto source_ty = getMemRefType(op.getMemRef()); |
| if (!source_ty.hasStaticShape()) { |
| return op.emitOpError( |
| "Only slicing of memrefs with static shapes is supported."); |
| } |
| auto source_shape = source_ty.getShape(); |
| bool is_semaphore = |
| hasMemorySpace(source_ty, tpu::MemorySpace::kSemaphoreMem); |
| if (is_semaphore && !isa<SemaphoreType, DMASemaphoreType>( |
| source_ty.getElementType())) { |
| return op.emitOpError( |
| "References to semaphore memory space must have a semaphore " |
| "element type."); |
| } |
| if (indices.size() != slice_shape.size() || |
| indices.size() != source_shape.size()) { |
|             op.emitOpError( |
|                 "Base indices, slice shape, and source shape must have the " |
|                 "same rank."); |
| return failure(); |
| } |
| auto tiled_layout = |
| dyn_cast<tpu::TiledLayoutAttr>(source_ty.getLayout()); |
| if (!tiled_layout) { |
| op.emitOpError("Must have tiled layout."); |
| return failure(); |
| } |
| absl::Span<const int64_t> tile; |
| if (!tiled_layout.getTiles().empty()) { |
| // We assume that only the highest-level tile matters. |
| // All lower-level tiles only rearrange the data within the tile. |
| tile = tiled_layout.getTiles().front().dimensions(); |
| } |
| if (source_shape.size() < tile.size()) { |
| op.emitOpError("Source rank must not be smaller than tile rank."); |
| return failure(); |
| } |
| |
| // Check that sliced portions are aligned to tile boundaries. |
| // We use the values before dialect conversion, since ops like |
| // tpu.assume_multiple get eliminated. |
| auto tiled_indices = op.getBaseIdx().take_back(tile.size()); |
| auto tiled_slice_shape = slice_shape.take_back(tile.size()); |
| bool is_aligned = true; |
| for (int64_t i = 0; i < tile.size(); ++i) { |
| is_aligned &= tiled_slice_shape[i] % tile[i] == 0; |
| if (!isGuaranteedDivisible(tiled_indices[i], tile[i])) { |
| op.emitOpError( |
| "Failed to prove that a tile index is divisible by the " |
| "tiling."); |
| return failure(); |
| } |
| } |
| if (!is_aligned) { |
| op.emitOpError("Slice shape must be aligned to tile boundaries."); |
| return failure(); |
| } |
| |
| auto int_const = [&](int i) -> Value { |
| return rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(i)); |
| }; |
| Value offset = int_const(0); |
| auto tile_strides = tiled_layout.getTileStrides(); |
| int tile_element_size = std::accumulate(tile.begin(), tile.end(), 1, |
| std::multiplies<int>()); |
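|           // Worked example (values assumed): an (8, 128) tile with |
|           // tile_strides (2, 1) and base indices (8, 128) yields tile |
|           // coordinates (1, 1), so offset == 1 * 2 + 1 * 1 == 3 tiles, which |
|           // AddrScaledOp scales by tile_size_bytes. |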
|           // TODO(apaszke): More often than not it's the trailing indices |
|           // that are static and not the leading ones. If we computed the sum |
|           // in the opposite order, we might expose constant folding |
|           // opportunities. |
| for (int i = indices.size() - 1; i >= 0; --i) { |
| auto idx = indices[i]; |
| auto dim_size = source_shape[i]; |
|             int64_t tile_i = i - static_cast<int64_t>(indices.size()) + |
|                              static_cast<int64_t>(tile.size()); |
|             if (tile_i >= 0 && tile_i < static_cast<int64_t>(tile.size())) { |
| idx = rewriter.create<llo::ScalarDivS32Op>( |
| op.getLoc(), rewriter.getI32Type(), idx, |
| int_const(tile[tile_i])); |
| dim_size /= tile[tile_i]; |
| } |
| auto scaled_index = rewriter.create<llo::ScalarMulS32Op>( |
| op.getLoc(), rewriter.getI32Type(), int_const(tile_strides[i]), |
| idx); |
| offset = rewriter.create<llo::ScalarAddS32Op>( |
| op.getLoc(), rewriter.getI32Type(), offset, scaled_index); |
| } |
| int bitwidth; |
| if (is_semaphore) { |
| bitwidth = target_->SflagWordSizeBytes() * 8; |
| } else { |
| bitwidth = source_ty.getElementTypeBitWidth(); |
| } |
| const auto tile_size_bytes = tile_element_size * bitwidth / 8; |
| auto addr = rewriter.create<llo::AddrScaledOp>( |
| op.getLoc(), subst.getMemRef(), offset, tile_size_bytes); |
| if (op.getResult().getType().hasStaticShape()) { |
| rewriter.replaceOp(op, addr); |
| } else { |
| SmallVector<Value> memref_rep; |
| memref_rep.push_back(addr); |
| memref_rep.append(subst.getDynamicSizes().begin(), |
| subst.getDynamicSizes().end()); |
| rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>( |
| op, op.getType(), memref_rep); |
| } |
| return success(); |
| }); |
| addPattern<tpu::BroadcastInSublanesOp>( |
| patterns, [](tpu::BroadcastInSublanesOp op, |
| tpu::BroadcastInSublanesOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| Value lane = rewriter.create<llo::ConstantOp>( |
| op.getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(op.getLane())); |
| rewriter.replaceOpWithNewOp<llo::VectorBroadcastSublaneChunkOp>( |
| op, op.getType(), subst.getSource(), lane); |
| return success(); |
| }); |
| addPattern<tpu::GetBarrierSemaphoreOp>( |
| patterns, [](tpu::GetBarrierSemaphoreOp op, |
| tpu::GetBarrierSemaphoreOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::GetBarrierSyncFlagOp>(op); |
| return success(); |
| }); |
| addPattern<tpu::TraceOp>( |
| patterns, [](tpu::TraceOp op, tpu::TraceOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| llo::TraceOp new_op = rewriter.create<llo::TraceOp>( |
| op.getLoc(), TypeRange{}, op.getMessageAttr(), op.getLevelAttr()); |
| CHECK_EQ(new_op->getNumRegions(), 1); |
| rewriter.cloneRegionBefore(op.getRegion(), new_op.getRegion(), |
| new_op.getRegion().end()); |
| rewriter.eraseOp(op); |
| return success(); |
| }); |
| addPattern<tpu::TraceStartOp>( |
| patterns, [](tpu::TraceStartOp op, tpu::TraceStartOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::TraceStartOp>( |
| op, op.getMessageAttr(), op.getLevelAttr()); |
| return success(); |
| }); |
| addPattern<tpu::TraceStopOp>( |
| patterns, [](tpu::TraceStopOp op, tpu::TraceStopOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::TraceStopOp>(op); |
| return success(); |
| }); |
| addPattern<tpu::RegionOp>( |
| patterns, [](tpu::RegionOp op, tpu::RegionOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| llo::RegionOp new_op = |
| rewriter.create<llo::RegionOp>(op.getLoc(), TypeRange{}); |
| if (new_op->getNumRegions() != 1) { |
| return failure(); |
| } |
| rewriter.inlineRegionBefore(op.getRegion(), new_op.getRegion(), |
| new_op.getRegion().end()); |
| rewriter.replaceOp(op, new_op); |
| return success(); |
| }); |
| addPattern<tpu::YieldOp>( |
| patterns, [](tpu::YieldOp op, tpu::YieldOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::YieldOp>(op, subst.getOperands()); |
| return success(); |
| }); |
| addPattern<tpu::MaskCastOp>( |
| patterns, [this](tpu::MaskCastOp op, tpu::MaskCastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| if (!isMaskVectorType(op.getInput().getType()) || |
| !isMaskVectorType(op.getType())) { |
| return failure(); |
| } |
| rewriter.replaceOp(op, subst.getInput()); |
| return success(); |
| }); |
| addPattern<tpu::AssumeMultipleOp>( |
| patterns, |
| [](tpu::AssumeMultipleOp op, tpu::AssumeMultipleOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOp(op, subst.getValue()); |
| return success(); |
| }); |
| addPattern<tpu::GetIterationBoundOp>( |
| patterns, |
| [](tpu::GetIterationBoundOp op, tpu::GetIterationBoundOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.replaceOpWithNewOp<llo::GetIterationBoundOp>( |
| op, op.getDimAttr()); |
| return success(); |
| }); |
| } |
| |
| std::optional<int64_t> getMemRefSizeInBytes(MemRefType type) { |
|     auto layout = dyn_cast<tpu::TiledLayoutAttr>(type.getLayout()); |
| if (!type.hasStaticShape() || !layout) { |
| return std::nullopt; |
| } |
| std::vector<int64_t> minor_to_major(type.getRank()); |
| std::iota(minor_to_major.rbegin(), minor_to_major.rend(), |
| static_cast<int64_t>(0)); |
| xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDenseLayout( |
| xla::ConvertMlirTypeToPrimitiveType(type.getElementType()), |
| type.getShape(), minor_to_major, layout.getTiles()); |
| if (!xla::jellyfish::TransferSizeUtil::HasSupportedTiling( |
| *target_->Topology(), xla_shape)) { |
| return std::nullopt; |
| } |
| return target_->ShapeSizeCompact(xla_shape); |
| } |
| |
| LogicalResult getTiledDynamicShape(SmallVector<Value> *tiled_dynamic_shape, |
| ArrayRef<int64_t> shape, |
| ArrayRef<Value> dynamic_sizes, |
| absl::Span<const int64_t> tiling, |
| ConversionPatternRewriter &rewriter, |
| Operation *op) { |
| auto dynamic_dim_it = dynamic_sizes.begin(); |
| for (int64_t i = 0; i < shape.size() - tiling.size(); ++i) { |
| Value size; |
| if (mlir::ShapedType::isDynamic(shape[i])) { |
| CHECK(dynamic_dim_it != dynamic_sizes.end()); |
| size = *dynamic_dim_it++; |
| } else { |
| size = rewriter.create<llo::ConstantOp>( |
| op->getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(shape[i])); |
| } |
| tiled_dynamic_shape->push_back(size); |
| } |
| for (int i = 0; i < tiling.size(); ++i) { |
| int64_t size = shape[shape.size() - tiling.size() + i]; |
| if (mlir::ShapedType::isDynamic(size)) { |
| return op->emitOpError("Dynamic sizes in tiled dims unsupported"); |
| } |
| tiled_dynamic_shape->push_back(rewriter.create<llo::ConstantOp>( |
| op->getLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(llvm::divideCeil(size, tiling[i])))); |
| } |
| return success(); |
| } |
| |
| Value getDynamicNumElements(ArrayRef<int64_t> shape, |
| ArrayRef<Value> dynamic_sizes, |
| ConversionPatternRewriter &rewriter) { |
| int64_t num_static_elements = 1; |
| for (int64_t size : shape) { |
| num_static_elements *= mlir::ShapedType::isDynamic(size) ? 1 : size; |
| } |
| Value num_elements = rewriter.create<llo::ConstantOp>( |
| rewriter.getUnknownLoc(), rewriter.getI32Type(), |
| rewriter.getI32IntegerAttr(num_static_elements)); |
| for (Value size : dynamic_sizes) { |
| num_elements = rewriter.create<llo::ScalarMulS32Op>( |
| rewriter.getUnknownLoc(), rewriter.getI32Type(), num_elements, size); |
| } |
| return num_elements; |
| } |
| |
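|   // Splits a lowered memref value into its base address plus one i32 per |
|   // dynamic dimension, mirroring the memref type conversion installed in |
|   // runOnOperation (1 + num_dynamic_dims i32 values). |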
| std::pair<Value, SmallVector<Value>> unpackMemRef( |
| Value ref, ConversionPatternRewriter &rewriter) { |
| if (isa<IntegerType>(ref.getType())) { |
| return std::make_pair(ref, SmallVector<Value>{}); |
| } |
| CHECK(isa<MemRefType>(ref.getType())); |
| auto ref_ty = getMemRefType(ref); |
| int64_t num_dynamic_dims = 0; |
| for (auto size : ref_ty.getShape()) { |
| if (mlir::ShapedType::isDynamic(size)) { |
| ++num_dynamic_dims; |
| } |
| } |
| Type i32 = IntegerType::get(ref.getContext(), 32); |
| SmallVector<Type> memref_rep_tys(1 + num_dynamic_dims, i32); |
| auto cast = rewriter.create<UnrealizedConversionCastOp>( |
| rewriter.getUnknownLoc(), memref_rep_tys, ref); |
| auto sizes = cast.getResults().drop_front(1); |
| return std::make_pair(cast.getResult(0), |
| SmallVector<Value>(sizes.begin(), sizes.end())); |
| } |
| |
| const xla::jellyfish::Target *target_; |
| const bool unsafe_allow_multicore_remote_dma_; |
| }; |
| |
| void LowerToLLOPass::runOnOperation() { |
| std::unique_ptr<MockTpuInstance> mock_instance; |
| if (target_ == nullptr) { |
| mock_instance = tryCreatingMock(); |
| if (!mock_instance) { |
| signalPassFailure(); |
| return; |
| } |
| target_ = mock_instance->target(); |
| } else { |
|     CHECK(mock_target == -1) |
|         << "mock_target can only be set when the target is unspecified"; |
| } |
| TypeConverter converter; |
| converter.addConversion( |
| [](MemRefType ty, |
| mlir::SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { |
| int64_t dynamic_dims = 0; |
| for (auto size : ty.getShape()) { |
| dynamic_dims += mlir::ShapedType::isDynamic(size); |
| } |
| types.append(1 + dynamic_dims, IntegerType::get(ty.getContext(), 32)); |
| return success(); |
| }); |
| converter.addConversion([](tpu::SemaphoreType ty) -> std::optional<Type> { |
| return IntegerType::get(ty.getContext(), 32); |
| }); |
| converter.addConversion([](tpu::DMASemaphoreType ty) -> std::optional<Type> { |
| return IntegerType::get(ty.getContext(), 32); |
| }); |
| converter.addConversion([](IntegerType ty) -> std::optional<Type> { |
| if (ty.getWidth() > 32 || ty.getSignedness() != IntegerType::Signless) { |
| return std::nullopt; |
| } |
| return IntegerType::get(ty.getContext(), 32); |
| }); |
| converter.addConversion([](IndexType ty) -> std::optional<Type> { |
| return IntegerType::get(ty.getContext(), 32); |
| }); |
| converter.addConversion( |
| [](Float32Type ty) -> std::optional<Type> { return ty; }); |
| converter.addConversion([this](VectorType ty) -> std::optional<Type> { |
| if (!isRepresentableVectorType(ty)) { |
| return std::nullopt; |
| } |
| return ty; |
| }); |
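| // After conversion only LLO ops remain legal, plus func/return and the SCF |
| // ops whose types the converter accepts. |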
| ConversionTarget target(getContext()); |
| target.addLegalDialect<mlir::llo::LLODialect>(); |
| target.addLegalOp<func::FuncOp, func::ReturnOp>(); |
| target.addDynamicallyLegalOp<scf::IfOp>( |
| [&](scf::IfOp op) { return converter.isLegal(op.getResultTypes()); }); |
| target.addDynamicallyLegalOp<scf::ForOp>([&](scf::ForOp op) { |
| // Check the operand types to make sure that the loop bounds are not of |
| // index type. |
| return converter.isLegal(op.getOperandTypes()); |
| }); |
| target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) { |
| return op->getParentOp() && isa<scf::IfOp, scf::ForOp>(op->getParentOp()) && |
| converter.isLegal(op.getOperandTypes()); |
| }); |
| target.addDynamicallyLegalOp<func::FuncOp>([](func::FuncOp op) { |
| for (Value arg : op.getArguments()) { |
| if (!arg.getType().isSignlessInteger(32)) { |
| return false; |
| } |
| } |
| return true; |
| }); |
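| // Gather the per-dialect lowering patterns, plus structural SCF and function |
| // signature conversions driven by the type converter above. |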
| RewritePatternSet patterns(&getContext()); |
| populateVectorToLLOConversionPatterns(patterns); |
| populateArithToLLOConversionPatterns(patterns); |
| populateMathToLLOConversionPatterns(patterns); |
| populateMemrefToLLOConversionPatterns(patterns); |
| populateTPUToLLOConversionPatterns(patterns); |
| populateCFToLLOConversionPatterns(patterns); |
| addPattern<UnrealizedConversionCastOp>( |
| patterns, |
| [](UnrealizedConversionCastOp op, UnrealizedConversionCastOpAdaptor subst, |
| ConversionPatternRewriter &rewriter) { |
| rewriter.eraseOp(op); |
| return success(); |
| }); |
| populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, |
| converter); |
| scf::populateSCFStructuralTypeConversions(converter, patterns); |
| // Preserve the expected layouts before we proceed with the conversion. |
| func::FuncOp func = getOperation(); |
| for (int i = 0; i < func.getNumArguments(); ++i) { |
| auto arg_ty = func.getArgument(i).getType(); |
| // This is needed to support programs written directly in LLO. |
| if (!func.getArgAttr(i, "llo.type")) { |
| func.setArgAttr(i, "llo.type", TypeAttr::get(arg_ty)); |
| } |
| auto memref_ty = dyn_cast<MemRefType>(arg_ty); |
| if (!memref_ty) { |
| continue; |
| } |
| auto layout_attr = dyn_cast<tpu::TiledLayoutAttr>(memref_ty.getLayout()); |
| if (!layout_attr) { |
| func.emitOpError( |
| "all memref arguments must use TiledLayoutAttr as their layout"); |
| signalPassFailure(); |
| return; |
| } |
| func.setArgAttr(i, "llo.layout", layout_attr); |
| } |
| if (failed(applyFullConversion(func, target, std::move(patterns)))) { |
| signalPassFailure(); |
| } |
| if (mock_instance) { |
| target_ = nullptr; // Avoid leaving a dangling pointer. |
| } |
| } |
| |
| } // namespace |
| |
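| // Creates the full lowering pass for an explicit jellyfish target. The |
| // unsafe_allow_multicore_remote_dma flag is forwarded to the pass unchanged. |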
| std::unique_ptr<OperationPass<func::FuncOp>> createLowerToLLOPass( |
| const xla::jellyfish::Target &target, |
| bool unsafe_allow_multicore_remote_dma) { |
| return std::make_unique<LowerToLLOPass>(&target, |
| unsafe_allow_multicore_remote_dma); |
| } |
| |
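| // Creates the pass without a target; at run time a mock TPU instance is |
| // created to supply one. |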
| std::unique_ptr<OperationPass<func::FuncOp>> createPartialLowerToLLOPass() { |
| return std::make_unique<LowerToLLOPass>(); |
| } |
| |
| } // namespace mlir::tpu |