blob: 6319c559d01ad76dfe461655ef4e0d374754d9bb [file] [log] [blame]
#include "platforms/xla/mosaic/dialect/llo/llo_builder.h"
#include <cstdint>
#include <memory>
#include <optional>
#include <string_view>
#include "platforms/xla/mosaic/dialect/llo/llo_dialect.h"
#include "platforms/xla/service/jellyfish/dma_strides.h"
#include "platforms/xla/service/jellyfish/execution_profiler.h"
#include "platforms/xla/service/jellyfish/execution_profiler_traceme.h"
#include "platforms/xla/service/jellyfish/llo_instruction.h"
#include "platforms/xla/service/jellyfish/llo_predicated_region.h"
#include "platforms/xla/service/jellyfish/llo_region.h"
#include "platforms/xla/service/jellyfish/llo_region_builder.h"
#include "platforms/xla/service/jellyfish/llo_value.h"
#include "platforms/xla/service/jellyfish/lowering/fusion_util.h"
#include "platforms/xla/service/jellyfish/lowering/net_util.h"
#include "platforms/xla/service/jellyfish/target.h"
#include "third_party/absl/log/check.h"
#include "third_party/absl/log/log.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/STLExtras.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Dialect/SCF/IR/SCF.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/BuiltinTypes.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Types.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Value.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LLVM.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/tensorflow/compiler/xla/mlir/utils/type_util.h"
#include "third_party/tensorflow/compiler/xla/primitive_util.h"
#include "third_party/tensorflow/compiler/xla/shape_util.h"
namespace mlir {
namespace llo {
namespace {
// Returns the LLO value previously recorded for `val` in `value_map`.
//
// Callers look up values that an earlier lowering step is expected to have
// recorded. A missing entry is a lowering-order bug. It is CHECK-failed here
// rather than `assert`ed, because with `assert` an NDEBUG build would silently
// dereference `end()`.
::xla::jellyfish::LloValue* GetLloValue(
    mlir::Value val,
    const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {
  auto it = value_map.find(val);
  CHECK(it != value_map.end())
      << "MLIR value used before an LLO value was recorded for it";
  return it->second;
}
template <typename Op>
::xla::jellyfish::LloRegionBuilder::CoreLocation GetCoreLocation(
    BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Op op,
    std::string_view tag,
    const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {
  // Resolve the op's chip_id / core_index operands into their LLO values.
  ::xla::jellyfish::LloRegionBuilder::CoreLocation requested{
      .chip_id = GetLloValue(op.getChipId(), value_map),
      .core_index = GetLloValue(op.getCoreIndex(), value_map),
  };
  if (b.target().HasLimitedIciRouting()) {
    // On targets reporting limited ICI routing, the location is rewritten by
    // net_util (`tag` and the program-shared registry are forwarded to it).
    return ::xla::jellyfish::net_util::GetLimitedIciRoutingTableIndex(
        b, requested, tag, ctx->program_shared_registry,
        /*use_routing_table_indices=*/false, /*trap_if_invalid=*/false,
        /*must_return_index=*/true);
  }
  // Otherwise the (chip_id, core_index) pair is used as-is.
  return requested;
}
template <typename Op>
::xla::jellyfish::LloRegionBuilder::CoreLocation GetOptionalCoreLocation(
    BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Op op,
    std::string_view tag,
    const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {
  // chip_id and core_index are optional operands that must appear together.
  const bool has_chip_id = static_cast<bool>(op.getChipId());
  const bool has_core_index = static_cast<bool>(op.getCoreIndex());
  if (has_chip_id != has_core_index) {
    LOG(FATAL)
        << "Either both or none of core_index and chip_id should be specified";
  }
  // Neither operand given: return a default-constructed location.
  if (!has_chip_id) {
    return {};
  }
  return GetCoreLocation(ctx, b, op, tag, value_map);
}
// Builds the DMA stride list for a strided copy.
//
// `src_strides[i]` and `dst_strides[i]` are compile-time constants that are
// materialized as S32 immediates and interpreted as byte counts.
// `steps_per_stride[i]` is an SSA value looked up in `value_map`:
//   - For the last dimension it is converted from bytes.
//   - For every other dimension it is wrapped with `Granule::kInvalid`
//     (presumably a plain step count with no memory granule; confirm against
//     LloMemUnit).
// All three ranges must have the same length (CHECK-enforced).
::xla::jellyfish::DmaStrides<::xla::jellyfish::LloMemUnit> GetDmaStrides(
    xla::jellyfish::LloRegionBuilder& b, ArrayRef<int32_t> src_strides,
    ArrayRef<int32_t> dst_strides, OperandRange steps_per_stride,
    const DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {
  using xla::jellyfish::Granule;
  using xla::jellyfish::LloMemUnit;
  CHECK_EQ(src_strides.size(), dst_strides.size());
  CHECK_EQ(src_strides.size(), steps_per_stride.size());
  const size_t rank = src_strides.size();
  xla::jellyfish::DmaStrides<LloMemUnit> result;
  // Use an unsigned index so the bound check does not mix signedness.
  // `i + 1 == rank` identifies the last dimension without evaluating
  // `rank - 1`, which would wrap when rank is 0.
  for (size_t i = 0; i < rank; ++i) {
    ::xla::jellyfish::LloValue* step_value =
        GetLloValue(steps_per_stride[i], value_map);
    LloMemUnit steps;
    if (i + 1 == rank) {
      steps = b.LloMemUnitFromBytes(step_value);
    } else {
      steps = LloMemUnit(step_value, Granule::kInvalid);
    }
    result.emplace_back(b.LloMemUnitFromBytes(b.SimmS32(src_strides[i])),
                        b.LloMemUnitFromBytes(b.SimmS32(dst_strides[i])),
                        steps);
  }
  return result;
}
// Maps an MLIR element type to the corresponding XLA PrimitiveType.
// By default, signless integer types are rejected with a fatal error because
// their signedness cannot be inferred. Pass `throw_error_on_signless = false`
// to skip that check and let ConvertMlirTypeToPrimitiveType decide.
::xla::PrimitiveType ToPrimitiveType(mlir::Type type,
                                     bool throw_error_on_signless = true) {
  const bool reject_signless =
      throw_error_on_signless && type.isSignlessInteger();
  if (reject_signless) {
    LOG(FATAL) << "Signedness cannot be inferred";
  }
  return ::xla::ConvertMlirTypeToPrimitiveType(type);
}
#include "platforms/xla/mosaic/dialect/llo/llo_builder.cc.inc"
} // namespace
// Lowers every operation of `block` except its terminator, in order, through
// AppendInstruction. Stops at, and returns, the first failure. The terminator
// is skipped here; the parent op reads its operands directly when mapping
// results.
LogicalResult AppendBlock(
    BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b, Block& block,
    DenseMap<Value, ::xla::jellyfish::LloValue*>& value_map) {
  for (Operation& nested : block.without_terminator()) {
    if (failed(AppendInstruction(ctx, b, &nested, value_map))) {
      return failure();
    }
  }
  return success();
}
LogicalResult AppendInstruction(
BuilderContext* ctx, ::xla::jellyfish::LloRegionBuilder& b,
Operation* raw_op,
DenseMap<Value, ::xla::jellyfish::LloValue*>& value_map) {
if (auto op = dyn_cast<llo::ConstantOp>(raw_op)) {
::xla::jellyfish::LloValue* value = nullptr;
if (auto attr = dyn_cast<IntegerAttr>(op.getValue())) {
if (attr.getType().isUnsignedInteger(32)) {
value = b.SimmU32(attr.getUInt());
} else if (attr.getType().isSignlessInteger(32)) {
value = b.SimmS32(attr.getInt());
} else if (attr.getType().isSignlessInteger(1)) {
value = b.Pimm(attr.getInt());
}
} else if (auto attr = dyn_cast<FloatAttr>(op.getValue())) {
if (attr.getType().isF32()) {
value = b.SimmF32(attr.getValueAsDouble());
}
} else if (auto attr = dyn_cast<SplatElementsAttr>(op.getValue())) {
mlir::ShapedType ty = attr.getType();
int64_t sublanes = b.target().SublaneCount();
int64_t lanes = b.target().LaneCount();
if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes}) {
if (attr.getElementType().isF32()) {
value = b.VimmF32(attr.getSplatValue<float>());
} else if (attr.getElementType().isSignlessInteger(32)) {
value = b.VimmS32(attr.getSplatValue<int32_t>());
} else if (attr.getElementType().isSignlessInteger(1)) {
// TODO(b/272487785): The shape check above might require adjustment.
value = b.Vmimm(attr.getSplatValue<bool>());
}
} else if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes, 2}) {
if (attr.getElementType().isBF16()) {
uint16_t bits = attr.bitcast(IntegerType::get(op.getContext(), 16,
IntegerType::Unsigned))
.getSplatValue<uint16_t>();
value = b.VimmU32(bits | (bits << 16));
} else if (attr.getElementType().isSignlessInteger(16)) {
uint16_t bits = attr.getSplatValue<uint16_t>();
value = b.VimmU32(bits | (bits << 16));
} else if (b.target().BitsPerVmregLaneAndSublane() >= 2 &&
attr.getElementType().isSignlessInteger(1)) {
value = b.Vmimm(attr.getSplatValue<bool>());
}
} else if (ty.getShape() == ArrayRef<int64_t>{sublanes, lanes, 4}) {
if (b.target().BitsPerVmregLaneAndSublane() >= 4 &&
attr.getElementType().isSignlessInteger(1)) {
value = b.Vmimm(attr.getSplatValue<bool>());
} else if (attr.getElementType().isSignlessInteger(8)) {
uint8_t bits = attr.getSplatValue<uint8_t>();
value = b.VimmU32(bits | (bits << 8) | (bits << 16) | (bits << 24));
}
}
}
if (value != nullptr) {
value_map[op.getResult()] = value;
return success();
}
return failure();
}
if (auto op = dyn_cast<llo::VectorBitcastOp>(raw_op)) {
value_map[op.getOutput()] = GetLloValue(op.getInput(), value_map);
return success();
}
if (auto op = dyn_cast<llo::ScalarBitcastOp>(raw_op)) {
value_map[op.getOutput()] = GetLloValue(op.getInput(), value_map);
return success();
}
if (auto op = dyn_cast<mlir::scf::IfOp>(raw_op)) {
auto* pred = GetLloValue(op.getCondition(), value_map);
if (!op.getThenRegion().hasOneBlock()) {
return failure();
}
auto& then_block = op.getThenRegion().getBlocks().front();
if (!then_block.empty()) {
Operation* then_terminator = then_block.getTerminator();
if (then_terminator->getNumOperands() != op.getNumResults()) {
return failure();
}
for (int i = 0; i < op.getNumResults(); ++i) {
auto val = then_terminator->getOperand(i);
if (value_map.find(val) == value_map.end()) {
value_map[op.getResult(i)] =
val.getType().isa<VectorType>() ? b.VimmS32(0) : b.SimmS32(0);
}
}
xla::jellyfish::LloPredicatedRegion* subregion = b.Predicated(pred);
auto subbuilder = subregion->region_builder();
if (AppendBlock(ctx, subbuilder, then_block, value_map).failed()) {
return failure();
}
for (int i = 0; i < op.getNumResults(); ++i) {
auto& result = value_map[op.getResult(i)];
auto* then_result =
GetLloValue(then_terminator->getOperand(i), value_map);
result = (result == nullptr) ? then_result
: subregion->fallthru_region_builder().Phi(
result, then_result);
}
}
if (op.getElseRegion().empty()) {
return success();
}
if (!op.getElseRegion().hasOneBlock()) {
return failure();
}
auto& else_block = op.getElseRegion().getBlocks().front();
if (!else_block.empty()) {
xla::jellyfish::LloValue* neg_pred = b.Pneg(pred);
xla::jellyfish::LloPredicatedRegion* subregion = b.Predicated(neg_pred);
auto subbuilder = subregion->region_builder();
if (AppendBlock(ctx, subbuilder, else_block, value_map).failed()) {
return failure();
}
Operation* else_terminator = else_block.getTerminator();
if (else_terminator->getNumOperands() != op.getNumResults()) {
return failure();
}
for (int i = 0; i < op.getNumResults(); ++i) {
auto* result = GetLloValue(op.getResult(i), value_map);
auto* else_value =
GetLloValue(else_terminator->getOperand(i), value_map);
value_map[op.getResult(i)] =
subregion->fallthru_region_builder().Phi(result, else_value);
}
}
return success();
}
if (auto op = dyn_cast<mlir::scf::ForOp>(raw_op)) {
auto* lbd = GetLloValue(op.getLowerBound(), value_map);
auto* ubd = GetLloValue(op.getUpperBound(), value_map);
auto* step = GetLloValue(op.getStep(), value_map);
auto [loop, induction_variable] =
b.Loop(b.region()->hlo_instruction(), lbd, ubd, step);
value_map[op.getInductionVar()] = induction_variable;
// Add iter args to value_map.
for (auto [arg, init] :
llvm::zip_equal(op.getRegionIterArgs(), op.getInitArgs())) {
value_map[arg] = loop->header_builder().Phi(GetLloValue(init, value_map));
}
CHECK(op.getRegion().hasOneBlock());
xla::jellyfish::LloRegionBuilder body_builder(loop->body_builder());
if (AppendBlock(ctx, body_builder, *op.getBody(), value_map).failed()) {
return op.emitOpError("Failed to build llo loop body in scf.for.");
}
// Update iter args after each loop and finally map result to iter arg.
Operation* terminator = op.getBody()->getTerminator();
CHECK_EQ(terminator->getNumOperands(), op.getNumRegionIterArgs());
for (int i = 0; i < op.getNumRegionIterArgs(); ++i) {
value_map[op.getRegionIterArg(i)]->PhiAppend(
GetLloValue(terminator->getOperand(i), value_map));
value_map[op.getResult(i)] =
GetLloValue(op.getRegionIterArg(i), value_map);
}
return success();
}
if (auto op = dyn_cast<llo::GetBarrierSyncFlagOp>(raw_op)) {
if (!ctx->custom_barrier_sync_flag) {
return failure();
}
value_map[op.getResult()] = ctx->custom_barrier_sync_flag;
return success();
}
if (auto op = dyn_cast<llo::GetIterationBoundOp>(raw_op)) {
if (op.getDim() >= ctx->dynamic_iteration_bounds.size()) {
return failure();
}
value_map[op.getResult()] = ctx->dynamic_iteration_bounds[op.getDim()];
return success();
}
if (auto op = dyn_cast<llo::RegionOp>(raw_op)) {
xla::jellyfish::LloRegionBuilder subbuilder(
b.region()->AddRegion(b.region()->hlo_instruction()));
if (!op.getRegion().hasOneBlock()) {
return failure();
}
Block& block = op.getRegion().front();
if (AppendBlock(ctx, subbuilder, block, value_map).failed()) {
return failure();
}
Operation* terminator = block.getTerminator();
if (terminator->getNumOperands() != op.getNumResults()) {
return failure();
}
for (int i = 0; i < op.getNumResults(); ++i) {
value_map[op.getResult(i)] =
GetLloValue(terminator->getOperand(i), value_map);
}
return success();
}
if (auto op = dyn_cast<llo::TraceOp>(raw_op)) {
std::optional<xla::jellyfish::OnDeviceTraceMe> trace_guard;
if (xla::jellyfish::ExecutionProfiler* profiler = b.module()->profiler()) {
auto instr_id =
profiler->AllocateLloInstrumentationId(op.getMessage().str());
if (!instr_id.ok()) {
return failure();
}
trace_guard.emplace(b.region(), instr_id.value(), op.getLevel());
}
if (!op.getRegion().hasOneBlock()) {
return failure();
}
Block& block = op.getRegion().front();
if (AppendBlock(ctx, b, block, value_map).failed()) {
return failure();
}
Operation* terminator = block.getTerminator();
if (terminator->getNumOperands() != op.getNumResults()) {
return failure();
}
for (int i = 0; i < op.getNumResults(); ++i) {
value_map[op.getResult(i)] =
GetLloValue(terminator->getOperand(i), value_map);
}
return success();
}
if (auto op = dyn_cast<llo::TraceStartOp>(raw_op)) {
if (xla::jellyfish::ExecutionProfiler* profiler = b.module()->profiler()) {
auto instr_id =
profiler->AllocateLloInstrumentationId(op.getMessage().str());
if (!instr_id.ok()) {
return failure();
}
ctx->traceme_stack.push_back(
std::make_unique<xla::jellyfish::OnDeviceTraceMe>(
b.region(), instr_id.value(), op.getLevel()));
}
return success();
}
if (auto op = dyn_cast<llo::TraceStopOp>(raw_op)) {
if (b.module()->profiler()) {
ctx->traceme_stack.pop_back();
}
return success();
}
if (auto op = dyn_cast<llo::LogicalDeviceIdOp>(raw_op)) {
const auto& module_config = b.hlo()->GetModule()->config();
if (!module_config.has_static_device_assignment()) {
return failure();
}
const auto& device_assignment = module_config.static_device_assignment();
xla::jellyfish::LloValue* partition_id =
xla::jellyfish::net_util::GetPartitionId(b);
xla::jellyfish::LloValue* replica_id =
xla::jellyfish::net_util::GetReplicaId(b);
xla::jellyfish::LloValue* global_id = b.SaddS32(
b.SmulS32(replica_id, b.SimmS32(device_assignment.computation_count())),
partition_id);
value_map[op.getResult()] = global_id;
return success();
}
if (auto op = dyn_cast<llo::ScalarDivS32Op>(raw_op)) {
if (auto crhs = op.getRhs().getDefiningOp<llo::ConstantOp>()) {
auto divisor = cast<mlir::IntegerAttr>(crhs.getValue()).getInt();
if (divisor == 0) {
goto generic_div;
}
value_map[op.getResult()] =
b.SdivS32(GetLloValue(op.getLhs(), value_map), divisor);
return success();
}
generic_div:
value_map[op.getResult()] =
b.SdivS32General(GetLloValue(op.getLhs(), value_map),
GetLloValue(op.getRhs(), value_map));
return success();
}
if (auto op = dyn_cast<llo::ScalarRemS32Op>(raw_op)) {
if (auto crhs = op.getRhs().getDefiningOp<llo::ConstantOp>()) {
auto divisor = cast<mlir::IntegerAttr>(crhs.getValue()).getInt();
if (divisor == 0) {
goto generic_rem;
}
value_map[op.getResult()] =
b.SremS32(GetLloValue(op.getLhs(), value_map), divisor);
return success();
}
generic_rem:
value_map[op.getResult()] =
b.SremS32General(GetLloValue(op.getLhs(), value_map),
GetLloValue(op.getRhs(), value_map));
return success();
}
if (auto op = dyn_cast<llo::LogS32>(raw_op)) {
b.LogS32(GetLloValue(op.getOperand(), value_map), op.getTag().str());
return success();
}
if (auto op = dyn_cast<llo::LogPred>(raw_op)) {
b.LogPred(GetLloValue(op.getOperand(), value_map), op.getTag().str());
return success();
}
if (auto op = dyn_cast<llo::AllocaVmemOp>(raw_op)) {
int64_t num_words = op.getNumWords();
int64_t word_size = b.target().VmemWordSizeBytes();
auto addr = b.AllocateScopedVmem(
xla::ShapeUtil::MakeShape(::xla::U8, {num_words * word_size}));
value_map[op.getResult()] = addr;
return success();
}
if (auto op = dyn_cast<llo::AllocaSmemOp>(raw_op)) {
int64_t num_words = op.getNumWords();
int64_t word_size = b.target().SmemWordSizeBytes();
auto addr = b.AllocateScopedSmem(
xla::ShapeUtil::MakeShape(::xla::U8, {num_words * word_size}));
value_map[op.getResult()] = addr;
return success();
}
if (AppendInstructionGenerated(ctx, b, raw_op, value_map).succeeded()) {
return success();
}
raw_op->emitOpError("Failed to translate operation");
return failure();
}
} // namespace llo
} // namespace mlir