blob: cd3eeb527bd87456f114d38f3fce5d63ae74b710 [file] [log] [blame]
//===- TPULowerBranchPseudos.cpp - Remove BRreserve/BRs ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the OverPredicatePass and UnderPredicatePass.
//
// OverPredicatePass adds branch predicates to unpredicated instructions to
// allow scheduling flexibility.
// UnderPredicatePass removes unnecessary predicates from instructions after
// scheduling.
//
//===----------------------------------------------------------------------===//
#include "TPU.h"
#include "TPUSchedule.h"
#include "TPUSubtarget.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
using namespace llvm;
// Command-line flag defined in another translation unit. When it is set, both
// passes in this file first rewrite every BRcondT instruction to BRcond.
extern cl::opt<bool> BrCondForTest;
namespace {
class OverPredicate : public MachineFunctionPass {
public:
static char ID;
OverPredicate() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF);
StringRef getPassName() const override {
return "TPU overpredication";
}
private:
const TargetSubtargetInfo *ST;
const TargetInstrInfo *TII;
};
char OverPredicate::ID = 0;
class UnderPredicate : public MachineFunctionPass {
public:
static char ID;
UnderPredicate() : MachineFunctionPass(ID) {}
bool runOnMachineFunction(MachineFunction &MF);
StringRef getPassName() const override {
return "TPU underpredication";
}
private:
const TargetSubtargetInfo *ST;
};
char UnderPredicate::ID = 0;
} // namespace
// Pass registration boilerplate for the two passes defined in this file. The
// first string is the pass argument name, the second is the human-readable
// pass name (identical to what getPassName() returns). The trailing
// `false, false` are INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(OverPredicate, "tpu-over-predicate",
"TPU overpredication", false, false)
INITIALIZE_PASS(UnderPredicate, "tpu-under-predicate",
"TPU underpredication", false, false)
/// Over-predicates the instructions that follow a branch inside a block.
///
/// Within each basic block, once a BR or BRcond has been seen, every later
/// predicable instruction whose predicate is still "always" receives the
/// inverse of that branch's predicate. Tracking stops as soon as an
/// instruction defines the tracked predicate register, and starts again at
/// the next BR/BRcond.
///
/// When BrCondForTest is set, every BRcondT is first rewritten to BRcond.
///
/// \param MF the machine function to transform in place.
/// \returns true unconditionally (the function is always reported as
/// modified).
bool OverPredicate::runOnMachineFunction(MachineFunction &MF) {
  ST = &MF.getSubtarget();
  TII = ST->getInstrInfo();
  if (BrCondForTest) {
    for (auto &MBB : MF) {
      for (auto &MI : MBB) {
        if (MI.getOpcode() == TPU::BRcondT)
          MI.setDesc(TII->get(TPU::BRcond));
      }
    }
  }
  bool IsTensorCore = MF.getSubtarget<TPUSubtarget>().hasV1024();
  // Because we have multiple branch opcodes depending on whether the branch is
  // conditional or not, we need a helper that changes the branch type when we
  // want to apply a predicate to a branch.
  //
  // Rewriting a branch replaces (and erases) the original instruction, so the
  // helper returns the instruction that is live afterwards: either `&I`
  // itself or its replacement. Callers must not touch `I` again and must
  // continue with the returned pointer.
  auto ApplyPredicateToInstr = [this](MachineInstr &I,
                                      const TPUPredicate &P) -> MachineInstr * {
    // Non-branches only need their predicate operand updated. Returning here
    // keeps them out of the branch-rewriting logic below.
    if (!I.isBranch()) {
      P.applyTo(&I);
      return &I;
    }
    const unsigned Opcode = I.getOpcode();
    // If it is a straight BR then it already has the correct predicate.
    if (P.isAlways() && Opcode == TPU::BR)
      return &I;
    MachineBasicBlock &MBB = *I.getParent();
    // Converting an unconditional branch to conditional.
    if (Opcode == TPU::BR) {
      auto NewBR = BuildMI(MBB, I, I.getDebugLoc(), TII->get(TPU::BRcond))
                       .add(I.getOperand(0));
      P.addTo(&NewBR);
      I.eraseFromParent();
      return NewBR.getInstr();
    }
    // Converting a conditional branch to unconditional.
    if (P.isAlways()) {
      MachineInstr *NewBR =
          BuildMI(MBB, I, I.getDebugLoc(), TII->get(TPU::BR))
              .add(I.getOperand(0))
              .getInstr();
      I.eraseFromParent();
      return NewBR;
    }
    // Applying new conditional predicate.
    P.applyTo(&I);
    return &I;
  };
  for (auto &MBB : MF) {
    std::optional<TPUPredicate> InverseBranchPredicate;
    for (auto I = MBB.begin(), E = MBB.end(); I != E;) {
      // Advance the iterator before touching the instruction: the helper may
      // erase it and insert a replacement in front of it.
      MachineInstr *MI = &*I++;
      if (InverseBranchPredicate.has_value()) {
        // FIXME: We have to disallow over-predicating branch instruction
        // because it may lead to placing it in a delay slot of another branch
        // which is not supported by some components in existing TensorCore
        // infrastructure, namely the overlayer and DFC power verifier. We will
        // have to either fix or rewrite these components, but have to disable
        // this on TensorCore sub-targets until then. See b/141012999 for
        // details and discussion.
        if (MI->isPredicable() && TPUPredicate(*MI).isAlways() &&
            !(IsTensorCore && MI->isBranch())) {
          // Re-point MI at the live instruction; the original may be gone.
          MI = ApplyPredicateToInstr(*MI, InverseBranchPredicate.value());
        }
        if (MI->definesRegister(InverseBranchPredicate->getReg()))
          // Predicate register has been clobbered.
          InverseBranchPredicate.reset();
      }
      if (MI->getOpcode() == TPU::BR || MI->getOpcode() == TPU::BRcond)
        InverseBranchPredicate = TPUPredicate(MI).toggleInvert();
    }
  }
  return true;
}
/// Strips predicates that merely repeat the inverse of a preceding branch.
///
/// Within each basic block, every BR/BRcond starts tracking the inverse of
/// its own predicate. Any later instruction carrying exactly that predicate
/// is switched back to the default-constructed TPUPredicate. Tracking ends
/// when an instruction defines the tracked predicate register.
///
/// When BrCondForTest is set, every BRcondT is first rewritten to BRcond.
///
/// \param MF the machine function to transform in place.
/// \returns true unconditionally.
bool UnderPredicate::runOnMachineFunction(MachineFunction &MF) {
  ST = &MF.getSubtarget();
  const auto *InstrInfo = ST->getInstrInfo();

  // Test mode: normalise the BRcondT opcode to BRcond up front.
  if (BrCondForTest)
    for (auto &Block : MF)
      for (auto &Inst : Block)
        if (Inst.getOpcode() == TPU::BRcondT)
          Inst.setDesc(InstrInfo->get(TPU::BRcond));

  for (auto &Block : MF) {
    // Inverse predicate of the most recent branch in this block, if any.
    std::optional<TPUPredicate> FallThroughPred;
    for (auto &Inst : Block) {
      const unsigned Opc = Inst.getOpcode();
      const bool IsTrackedBranch = Opc == TPU::BR || Opc == TPU::BRcond;
      if (IsTrackedBranch) {
        // A branch only (re)starts tracking; it is never rewritten here.
        FallThroughPred = TPUPredicate(&Inst).toggleInvert();
        continue;
      }

      // Nothing to compare against until a branch has been seen.
      if (!FallThroughPred.has_value())
        continue;

      // The instruction repeats the inverse predicate of the preceding
      // branch, so the predicate is unnecessary: reset it to the default.
      if (TPUPredicate(&Inst) == *FallThroughPred)
        TPUPredicate().applyTo(&Inst);

      // Once the predicate register is redefined, the tracked predicate no
      // longer describes the branch condition.
      if (Inst.definesRegister(FallThroughPred->getReg()))
        FallThroughPred.reset();
    }
  }
  return true;
}
/// Creates a new heap-allocated instance of the over-predication pass.
Pass *llvm::createTPUOverPredicatePass() {
  return new OverPredicate();
}
/// Creates a new heap-allocated instance of the under-predication pass.
Pass *llvm::createTPUUnderPredicatePass() {
  return new UnderPredicate();
}