| #include "platforms/xla/mosaic/dialect/llo/llo_builder.h" |
| |
| #include <cstdint> |
| #include <memory> |
| #include <optional> |
| #include <string_view> |
| |
| #include "platforms/xla/mosaic/dialect/llo/llo_dialect.h" |
| #include "platforms/xla/service/jellyfish/dma_strides.h" |
| #include "platforms/xla/service/jellyfish/execution_profiler.h" |
| #include "platforms/xla/service/jellyfish/execution_profiler_traceme.h" |
| #include "platforms/xla/service/jellyfish/llo_instruction.h" |
| #include "platforms/xla/service/jellyfish/llo_predicated_region.h" |
| #include "platforms/xla/service/jellyfish/llo_region.h" |
| #include "platforms/xla/service/jellyfish/llo_region_builder.h" |
| #include "platforms/xla/service/jellyfish/llo_value.h" |
| #include "platforms/xla/service/jellyfish/lowering/fusion_util.h" |
| #include "platforms/xla/service/jellyfish/lowering/net_util.h" |
| #include "platforms/xla/service/jellyfish/target.h" |
| #include "third_party/absl/log/check.h" |
| #include "third_party/absl/log/log.h" |
| #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypes.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" |
| #include "third_party/tensorflow/compiler/xla/mlir/utils/type_util.h" |
| #include "third_party/tensorflow/compiler/xla/primitive_util.h" |
| #include "third_party/tensorflow/compiler/xla/shape_util.h" |
| |
| namespace mlir { |
| namespace llo { |
| |
| namespace { |
| |
| // Returns the LLO value recorded in `value_map` for the MLIR value `val`. |
| // |
| // Callers rely on the entry being present, so a missing entry is treated as a |
| // programming error. CHECK is used instead of assert so that the lookup is |
| // still validated in builds where assert is compiled out; otherwise the end |
| // iterator would be dereferenced. |
| ::xla::jellyfish::LloValue* GetLloValue( |
|     mlir::Value val, |
|     const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   const auto it = value_map.find(val); |
|   CHECK(it != value_map.end()) |
|       << "MLIR value has no LLO value recorded in the value map"; |
|   return it->second; |
| } |
| |
| // Builds the core location described by the chip id and core index operands |
| // of `op`. When the target reports limited ICI routing, the location is passed |
| // through net_util::GetLimitedIciRoutingTableIndex and that result is returned |
| // instead. |
| template <typename Op> |
| ::xla::jellyfish::LloRegionBuilder::CoreLocation GetCoreLocation( |
|     BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Op op, |
|     std::string_view tag, |
|     const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   ::xla::jellyfish::LloValue* const chip = |
|       GetLloValue(op.getChipId(), value_map); |
|   ::xla::jellyfish::LloValue* const core = |
|       GetLloValue(op.getCoreIndex(), value_map); |
|   ::xla::jellyfish::LloRegionBuilder::CoreLocation raw_location{ |
|       .chip_id = chip, |
|       .core_index = core, |
|   }; |
|   if (b.target().HasLimitedIciRouting()) { |
|     return ::xla::jellyfish::net_util::GetLimitedIciRoutingTableIndex( |
|         b, raw_location, tag, ctx->program_shared_registry, |
|         /*use_routing_table_indices=*/false, /*trap_if_invalid=*/false, |
|         /*must_return_index=*/true); |
|   } |
|   return raw_location; |
| } |
| |
| // Variant of GetCoreLocation for ops whose chip id / core index operands are |
| // optional. Returns a value-initialized location when neither operand is |
| // present, and aborts when only one of the two is present. |
| template <typename Op> |
| ::xla::jellyfish::LloRegionBuilder::CoreLocation GetOptionalCoreLocation( |
|     BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Op op, |
|     std::string_view tag, |
|     const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   const bool has_chip_id = static_cast<bool>(op.getChipId()); |
|   const bool has_core_index = static_cast<bool>(op.getCoreIndex()); |
|   if (has_chip_id != has_core_index) { |
|     LOG(FATAL) |
|         << "Either both or none of core_index and chip_id should be specified"; |
|   } |
|   if (!has_chip_id) { |
|     return {}; |
|   } |
|   return GetCoreLocation(ctx, b, op, tag, value_map); |
| } |
| |
| // Assembles a DmaStrides list with one entry per position of `src_strides`. |
| // |
| // `src_strides`, `dst_strides` and `steps_per_stride` must all have the same |
| // length (CHECK-enforced). The static strides are converted with |
| // LloMemUnitFromBytes. The step count of the last entry is converted with |
| // LloMemUnitFromBytes as well, while every other step count is wrapped in an |
| // LloMemUnit tagged Granule::kInvalid. |
| ::xla::jellyfish::DmaStrides<::xla::jellyfish::LloMemUnit> GetDmaStrides( |
|     xla::jellyfish::LloRegionBuilder& b, ArrayRef<int32_t> src_strides, |
|     ArrayRef<int32_t> dst_strides, OperandRange steps_per_stride, |
|     const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   using xla::jellyfish::Granule; |
|   using xla::jellyfish::LloMemUnit; |
|   const size_t num_entries = src_strides.size(); |
|   CHECK_EQ(num_entries, dst_strides.size()); |
|   CHECK_EQ(num_entries, steps_per_stride.size()); |
|   xla::jellyfish::DmaStrides<LloMemUnit> strides; |
|   for (size_t idx = 0; idx < num_entries; ++idx) { |
|     ::xla::jellyfish::LloValue* step_value = |
|         GetLloValue(steps_per_stride[idx], value_map); |
|     const bool is_last = (idx + 1 == num_entries); |
|     LloMemUnit steps; |
|     if (is_last) { |
|       steps = b.LloMemUnitFromBytes(step_value); |
|     } else { |
|       steps = LloMemUnit(step_value, Granule::kInvalid); |
|     } |
|     strides.emplace_back(b.LloMemUnitFromBytes(b.SimmS32(src_strides[idx])), |
|                          b.LloMemUnitFromBytes(b.SimmS32(dst_strides[idx])), |
|                          steps); |
|   } |
|   return strides; |
| } |
| |
| // Converts an MLIR type to the corresponding XLA primitive type. |
| // |
| // When `throw_error_on_signless` is true (the default), a signless integer |
| // type aborts the process instead of being converted. |
| ::xla::PrimitiveType ToPrimitiveType(mlir::Type type, |
|                                      bool throw_error_on_signless = true) { |
|   const bool reject_signless = |
|       throw_error_on_signless && type.isSignlessInteger(); |
|   if (reject_signless) { |
|     LOG(FATAL) << "Signedness cannot be inferred"; |
|   } |
|   return ::xla::ConvertMlirTypeToPrimitiveType(type); |
| } |
| |
| #include "platforms/xla/mosaic/dialect/llo/llo_builder.cc.inc" |
| |
| } // namespace |
| |
| // Translates every operation of `block` except its terminator, in order, by |
| // calling AppendInstruction. Stops at the first operation that fails and |
| // returns failure(); returns success() when all operations were translated. |
| LogicalResult AppendBlock( |
|     BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Block& block, |
|     DenseMap<Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   const bool translated_all = llvm::all_of( |
|       block.without_terminator(), [&](mlir::Operation& nested_op) { |
|         return AppendInstruction(ctx, b, &nested_op, value_map).succeeded(); |
|       }); |
|   return success(translated_all); |
| } |
| |
| // Translates the single MLIR operation `raw_op` into LLO emitted through `b`. |
| // |
| // The LLO values produced for the operation's results are recorded in |
| // `value_map`, where later operations look up their operands. Constants, |
| // bitcasts, scf.if / scf.for, nested regions, tracing, device id, signed |
| // division / remainder, logging and scoped allocations are handled here; any |
| // other operation is forwarded to AppendInstructionGenerated. |
| // |
| // Returns failure() when the operation cannot be translated. |
| LogicalResult AppendInstruction( |
|     BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, |
|     Operation* raw_op, |
|     DenseMap<Value, ::xla::jellyfish::LloValue*>& value_map) { |
|   // Constants: pick the immediate builder that matches the attribute type. |
|   // Unsupported attribute types leave `value` null and yield failure(). |
|   if (auto op = dyn_cast<llo::ConstantOp>(raw_op)) { |
|     ::xla::jellyfish::LloValue* value = nullptr; |
|     if (auto attr = dyn_cast<IntegerAttr>(op.getValue())) { |
|       if (attr.getType().isUnsignedInteger(32)) { |
|         value = b.SimmU32(attr.getUInt()); |
|       } else if (attr.getType().isSignlessInteger(32)) { |
|         value = b.SimmS32(attr.getInt()); |
|       } else if (attr.getType().isSignlessInteger(1)) { |
|         value = b.Pimm(attr.getInt()); |
|       } |
|     } else if (auto attr = dyn_cast<FloatAttr>(op.getValue())) { |
|       if (attr.getType().isF32()) { |
|         value = b.SimmF32(attr.getValueAsDouble()); |
|       } |
|     } else if (auto attr = dyn_cast<SplatElementsAttr>(op.getValue())) { |
|       mlir::ShapedType ty = attr.getType(); |
|       int64_t sublanes = b.target().SublaneCount(); |
|       int64_t lanes = b.target().LaneCount(); |
|       if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes}) { |
|         if (attr.getElementType().isF32()) { |
|           value = b.VimmF32(attr.getSplatValue<float>()); |
|         } else if (attr.getElementType().isSignlessInteger(32)) { |
|           value = b.VimmS32(attr.getSplatValue<int32_t>()); |
|         } else if (attr.getElementType().isSignlessInteger(1)) { |
|           // TODO(b/272487785): The shape check above might require adjustment. |
|           value = b.Vmimm(attr.getSplatValue<bool>()); |
|         } |
|       } else if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes, 2}) { |
|         // Two 16-bit elements per 32-bit word: replicate the splat pattern in |
|         // both halves. The arithmetic is done in uint32_t so that the shift |
|         // never goes through a promoted signed int. |
|         if (attr.getElementType().isBF16()) { |
|           const uint32_t bits = |
|               attr.bitcast(IntegerType::get(op.getContext(), 16, |
|                                             IntegerType::Unsigned)) |
|                   .getSplatValue<uint16_t>(); |
|           value = b.VimmU32(bits | (bits << 16)); |
|         } else if (attr.getElementType().isSignlessInteger(16)) { |
|           const uint32_t bits = attr.getSplatValue<uint16_t>(); |
|           value = b.VimmU32(bits | (bits << 16)); |
|         } else if (b.target().BitsPerVmregLaneAndSublane() >= 2 && |
|                    attr.getElementType().isSignlessInteger(1)) { |
|           value = b.Vmimm(attr.getSplatValue<bool>()); |
|         } |
|       } else if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes, 4}) { |
|         if (b.target().BitsPerVmregLaneAndSublane() >= 4 && |
|             attr.getElementType().isSignlessInteger(1)) { |
|           value = b.Vmimm(attr.getSplatValue<bool>()); |
|         } else if (attr.getElementType().isSignlessInteger(8)) { |
|           // Four 8-bit elements per 32-bit word: replicate the byte four |
|           // times, again in uint32_t. |
|           const uint32_t bits = attr.getSplatValue<uint8_t>(); |
|           value = b.VimmU32(bits | (bits << 8) | (bits << 16) | (bits << 24)); |
|         } |
|       } |
|     } |
|     if (value != nullptr) { |
|       value_map[op.getResult()] = value; |
|       return success(); |
|     } |
|     return failure(); |
|   } |
|   // Bitcasts emit nothing: the output simply aliases the input's LLO value. |
|   if (auto op = dyn_cast<llo::VectorBitcastOp>(raw_op)) { |
|     value_map[op.getOutput()] = GetLloValue(op.getInput(), value_map); |
|     return success(); |
|   } |
|   if (auto op = dyn_cast<llo::ScalarBitcastOp>(raw_op)) { |
|     value_map[op.getOutput()] = GetLloValue(op.getInput(), value_map); |
|     return success(); |
|   } |
|   // scf.if: the then-block is emitted into a region predicated on the |
|   // condition and the else-block into a region predicated on its negation. |
|   // Results are merged with Phi nodes in the fallthrough region builders. |
|   if (auto op = dyn_cast<mlir::scf::IfOp>(raw_op)) { |
|     auto* pred = GetLloValue(op.getCondition(), value_map); |
|     if (!op.getThenRegion().hasOneBlock()) { |
|       return failure(); |
|     } |
|     auto& then_block = op.getThenRegion().getBlocks().front(); |
|     if (!then_block.empty()) { |
|       Operation* then_terminator = then_block.getTerminator(); |
|       if (then_terminator->getNumOperands() != op.getNumResults()) { |
|         return failure(); |
|       } |
|       // A yielded value that is not in the map yet is produced inside the |
|       // then-block; seed its result with a zero immediate created before the |
|       // predicated region so that there is a value to merge with afterwards. |
|       for (unsigned i = 0; i < op.getNumResults(); ++i) { |
|         auto val = then_terminator->getOperand(i); |
|         if (value_map.find(val) == value_map.end()) { |
|           value_map[op.getResult(i)] = |
|               val.getType().isa<VectorType>() ? b.VimmS32(0) : b.SimmS32(0); |
|         } |
|       } |
|       xla::jellyfish::LloPredicatedRegion* subregion = b.Predicated(pred); |
|       auto subbuilder = subregion->region_builder(); |
|       if (AppendBlock(ctx, subbuilder, then_block, value_map).failed()) { |
|         return failure(); |
|       } |
|       for (unsigned i = 0; i < op.getNumResults(); ++i) { |
|         auto* then_result = |
|             GetLloValue(then_terminator->getOperand(i), value_map); |
|         // lookup() yields nullptr when no zero seed was recorded above, in |
|         // which case the yielded value is used directly without a Phi. |
|         auto* seeded = value_map.lookup(op.getResult(i)); |
|         value_map[op.getResult(i)] = |
|             (seeded == nullptr) |
|                 ? then_result |
|                 : subregion->fallthru_region_builder().Phi(seeded, |
|                                                            then_result); |
|       } |
|     } |
|     if (op.getElseRegion().empty()) { |
|       return success(); |
|     } |
|     if (!op.getElseRegion().hasOneBlock()) { |
|       return failure(); |
|     } |
|     auto& else_block = op.getElseRegion().getBlocks().front(); |
|     if (!else_block.empty()) { |
|       xla::jellyfish::LloValue* neg_pred = b.Pneg(pred); |
|       xla::jellyfish::LloPredicatedRegion* subregion = b.Predicated(neg_pred); |
|       auto subbuilder = subregion->region_builder(); |
|       if (AppendBlock(ctx, subbuilder, else_block, value_map).failed()) { |
|         return failure(); |
|       } |
|       Operation* else_terminator = else_block.getTerminator(); |
|       if (else_terminator->getNumOperands() != op.getNumResults()) { |
|         return failure(); |
|       } |
|       // Merge the value recorded by the then-branch with the else value. |
|       for (unsigned i = 0; i < op.getNumResults(); ++i) { |
|         auto* result = GetLloValue(op.getResult(i), value_map); |
|         auto* else_value = |
|             GetLloValue(else_terminator->getOperand(i), value_map); |
|         value_map[op.getResult(i)] = |
|             subregion->fallthru_region_builder().Phi(result, else_value); |
|       } |
|     } |
|  |
|     return success(); |
|   } |
|   // scf.for: build an LLO loop; iter args become Phi nodes in the loop header |
|   // that are extended with the yielded values once the body has been built. |
|   if (auto op = dyn_cast<mlir::scf::ForOp>(raw_op)) { |
|     auto* lbd = GetLloValue(op.getLowerBound(), value_map); |
|     auto* ubd = GetLloValue(op.getUpperBound(), value_map); |
|     auto* step = GetLloValue(op.getStep(), value_map); |
|     auto [loop, induction_variable] = |
|         b.Loop(b.region()->hlo_instruction(), lbd, ubd, step); |
|     value_map[op.getInductionVar()] = induction_variable; |
|     // Add iter args to value_map. |
|     for (auto [arg, init] : |
|          llvm::zip_equal(op.getRegionIterArgs(), op.getInitArgs())) { |
|       value_map[arg] = loop->header_builder().Phi(GetLloValue(init, value_map)); |
|     } |
|     CHECK(op.getRegion().hasOneBlock()); |
|     xla::jellyfish::LloRegionBuilder body_builder(loop->body_builder()); |
|     if (AppendBlock(ctx, body_builder, *op.getBody(), value_map).failed()) { |
|       return op.emitOpError("Failed to build llo loop body in scf.for."); |
|     } |
|     // Update iter args after each loop and finally map result to iter arg. |
|     Operation* terminator = op.getBody()->getTerminator(); |
|     CHECK_EQ(terminator->getNumOperands(), op.getNumRegionIterArgs()); |
|     for (unsigned i = 0; i < op.getNumRegionIterArgs(); ++i) { |
|       auto* iter_phi = GetLloValue(op.getRegionIterArg(i), value_map); |
|       iter_phi->PhiAppend(GetLloValue(terminator->getOperand(i), value_map)); |
|       value_map[op.getResult(i)] = iter_phi; |
|     } |
|     return success(); |
|   } |
|   // Values supplied by the builder context rather than computed here. |
|   if (auto op = dyn_cast<llo::GetBarrierSyncFlagOp>(raw_op)) { |
|     if (!ctx->custom_barrier_sync_flag) { |
|       return failure(); |
|     } |
|     value_map[op.getResult()] = ctx->custom_barrier_sync_flag; |
|     return success(); |
|   } |
|   if (auto op = dyn_cast<llo::GetIterationBoundOp>(raw_op)) { |
|     if (op.getDim() >= ctx->dynamic_iteration_bounds.size()) { |
|       return failure(); |
|     } |
|     value_map[op.getResult()] = ctx->dynamic_iteration_bounds[op.getDim()]; |
|     return success(); |
|   } |
|   // llo.region: the body is emitted into a freshly added sub-region and the |
|   // op results alias the operands of the body's terminator. |
|   if (auto op = dyn_cast<llo::RegionOp>(raw_op)) { |
|     xla::jellyfish::LloRegionBuilder subbuilder( |
|         b.region()->AddRegion(b.region()->hlo_instruction())); |
|     if (!op.getRegion().hasOneBlock()) { |
|       return failure(); |
|     } |
|     Block& block = op.getRegion().front(); |
|     if (AppendBlock(ctx, subbuilder, block, value_map).failed()) { |
|       return failure(); |
|     } |
|     Operation* terminator = block.getTerminator(); |
|     if (terminator->getNumOperands() != op.getNumResults()) { |
|       return failure(); |
|     } |
|     for (unsigned i = 0; i < op.getNumResults(); ++i) { |
|       value_map[op.getResult(i)] = |
|           GetLloValue(terminator->getOperand(i), value_map); |
|     } |
|     return success(); |
|   } |
|   // llo.trace: the body is emitted into the current region. When a profiler |
|   // is attached, `trace_guard` lives until this handler returns. |
|   if (auto op = dyn_cast<llo::TraceOp>(raw_op)) { |
|     std::optional<xla::jellyfish::OnDeviceTraceMe> trace_guard; |
|     if (xla::jellyfish::ExecutionProfiler* profiler = b.module()->profiler()) { |
|       auto instr_id = |
|           profiler->AllocateLloInstrumentationId(op.getMessage().str()); |
|       if (!instr_id.ok()) { |
|         return failure(); |
|       } |
|       trace_guard.emplace(b.region(), instr_id.value(), op.getLevel()); |
|     } |
|     if (!op.getRegion().hasOneBlock()) { |
|       return failure(); |
|     } |
|     Block& block = op.getRegion().front(); |
|     if (AppendBlock(ctx, b, block, value_map).failed()) { |
|       return failure(); |
|     } |
|     Operation* terminator = block.getTerminator(); |
|     if (terminator->getNumOperands() != op.getNumResults()) { |
|       return failure(); |
|     } |
|     for (unsigned i = 0; i < op.getNumResults(); ++i) { |
|       value_map[op.getResult(i)] = |
|           GetLloValue(terminator->getOperand(i), value_map); |
|     } |
|     return success(); |
|   } |
|   // Unstructured tracing: trace start pushes a guard on the context's stack |
|   // and trace stop pops the most recent one. Both are no-ops without a |
|   // profiler. |
|   if (auto op = dyn_cast<llo::TraceStartOp>(raw_op)) { |
|     if (xla::jellyfish::ExecutionProfiler* profiler = b.module()->profiler()) { |
|       auto instr_id = |
|           profiler->AllocateLloInstrumentationId(op.getMessage().str()); |
|       if (!instr_id.ok()) { |
|         return failure(); |
|       } |
|       ctx->traceme_stack.push_back( |
|           std::make_unique<xla::jellyfish::OnDeviceTraceMe>( |
|               b.region(), instr_id.value(), op.getLevel())); |
|     } |
|     return success(); |
|   } |
|   if (auto op = dyn_cast<llo::TraceStopOp>(raw_op)) { |
|     if (b.module()->profiler()) { |
|       // Popping an empty stack is undefined behaviour, so an unmatched trace |
|       // stop is reported as a translation failure instead. |
|       if (ctx->traceme_stack.empty()) { |
|         return op.emitOpError("has no matching trace start"); |
|       } |
|       ctx->traceme_stack.pop_back(); |
|     } |
|     return success(); |
|   } |
|   // Logical device id = replica_id * computation_count + partition_id. This |
|   // requires a static device assignment in the module config. |
|   if (auto op = dyn_cast<llo::LogicalDeviceIdOp>(raw_op)) { |
|     const auto& module_config = b.hlo()->GetModule()->config(); |
|     if (!module_config.has_static_device_assignment()) { |
|       return failure(); |
|     } |
|     const auto& device_assignment = module_config.static_device_assignment(); |
|     xla::jellyfish::LloValue* partition_id = |
|         xla::jellyfish::net_util::GetPartitionId(b); |
|     xla::jellyfish::LloValue* replica_id = |
|         xla::jellyfish::net_util::GetReplicaId(b); |
|     xla::jellyfish::LloValue* global_id = b.SaddS32( |
|         b.SmulS32(replica_id, b.SimmS32(device_assignment.computation_count())), |
|         partition_id); |
|     value_map[op.getResult()] = global_id; |
|     return success(); |
|   } |
|   // Signed division: a non-zero constant divisor uses the constant-divisor |
|   // builder; everything else (including a constant zero) uses the general one. |
|   if (auto op = dyn_cast<llo::ScalarDivS32Op>(raw_op)) { |
|     if (auto crhs = op.getRhs().getDefiningOp<llo::ConstantOp>()) { |
|       auto divisor_attr = dyn_cast<mlir::IntegerAttr>(crhs.getValue()); |
|       if (divisor_attr && divisor_attr.getInt() != 0) { |
|         const auto divisor = divisor_attr.getInt(); |
|         value_map[op.getResult()] = |
|             b.SdivS32(GetLloValue(op.getLhs(), value_map), divisor); |
|         return success(); |
|       } |
|     } |
|     value_map[op.getResult()] = |
|         b.SdivS32General(GetLloValue(op.getLhs(), value_map), |
|                          GetLloValue(op.getRhs(), value_map)); |
|     return success(); |
|   } |
|   // Signed remainder: same constant / general split as division. |
|   if (auto op = dyn_cast<llo::ScalarRemS32Op>(raw_op)) { |
|     if (auto crhs = op.getRhs().getDefiningOp<llo::ConstantOp>()) { |
|       auto divisor_attr = dyn_cast<mlir::IntegerAttr>(crhs.getValue()); |
|       if (divisor_attr && divisor_attr.getInt() != 0) { |
|         const auto divisor = divisor_attr.getInt(); |
|         value_map[op.getResult()] = |
|             b.SremS32(GetLloValue(op.getLhs(), value_map), divisor); |
|         return success(); |
|       } |
|     } |
|     value_map[op.getResult()] = |
|         b.SremS32General(GetLloValue(op.getLhs(), value_map), |
|                          GetLloValue(op.getRhs(), value_map)); |
|     return success(); |
|   } |
|   // Logging ops produce no result value. |
|   if (auto op = dyn_cast<llo::LogS32>(raw_op)) { |
|     b.LogS32(GetLloValue(op.getOperand(), value_map), op.getTag().str()); |
|     return success(); |
|   } |
|   if (auto op = dyn_cast<llo::LogPred>(raw_op)) { |
|     b.LogPred(GetLloValue(op.getOperand(), value_map), op.getTag().str()); |
|     return success(); |
|   } |
|   // Scoped allocations: the size is given in words and converted to a U8 |
|   // shape using the target's word size for the respective memory. |
|   if (auto op = dyn_cast<llo::AllocaVmemOp>(raw_op)) { |
|     int64_t num_words = op.getNumWords(); |
|     int64_t word_size = b.target().VmemWordSizeBytes(); |
|     auto addr = b.AllocateScopedVmem( |
|         xla::ShapeUtil::MakeShape(::xla::U8, {num_words * word_size})); |
|     value_map[op.getResult()] = addr; |
|     return success(); |
|   } |
|   if (auto op = dyn_cast<llo::AllocaSmemOp>(raw_op)) { |
|     int64_t num_words = op.getNumWords(); |
|     int64_t word_size = b.target().SmemWordSizeBytes(); |
|     auto addr = b.AllocateScopedSmem( |
|         xla::ShapeUtil::MakeShape(::xla::U8, {num_words * word_size})); |
|     value_map[op.getResult()] = addr; |
|     return success(); |
|   } |
|   // Everything else is handled by the generated translation table. |
|   if (AppendInstructionGenerated(ctx, b, raw_op, value_map).succeeded()) { |
|     return success(); |
|   } |
|   raw_op->emitOpError("Failed to translate operation"); |
|   return failure(); |
| } |
| |
| } // namespace llo |
| } // namespace mlir |