blob: b55977b36bcbec3b5d94df7b99f6b7a83cff58f3 [file] [log] [blame]
//===-- TPUXLUOptimizations.cpp - Optimizations over XLU ops ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Optimizations over XLU operations.
//
//===----------------------------------------------------------------------===//
#include "TPU.h"
#include "TPUAliasSetTracker.h"
#include "TPUIRUtils.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SparseBitVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/MachineInstr.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsTPU.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/Verifier.h"
#include <algorithm>
#define DEBUG_TYPE "tpu-xlu-opt"
using namespace llvm;
using namespace llvm::TPU;
namespace {
// Command line flag consulted inside the LLVM_DEBUG blocks of trackDeps():
// when set, the dependencies computed for each instruction are dumped.
cl::opt<bool> TPUXLUOptsPrintDeps(
"tpu-xlu-opts-print-deps", cl::init(false),
cl::desc(
"When debug is enabled print XLU dependencies between instructions."));
// Set of XLU node indices, stored as a sparse bit vector.
using DependenciesSet = SparseBitVector<>;
// Custom alias set that tracks dependencies of XLU operations per alias set.
class TPUAliasSetWithDep : public TPUAliasSet {
  // Indices of the XLU nodes that the memory operations recorded in this
  // alias set depend on.
  DependenciesSet Deps;

public:
  // Mutable access to the dependencies attached to this alias set.
  DependenciesSet &getDeps() { return Deps; }
  // Merges Other into this set, including the dependencies it carries.
  void merge(TPUAliasSet &&Other) override;
};
// Unions the XLU dependencies carried by Other into this set, then lets the
// base class perform the actual alias set merge.
void TPUAliasSetWithDep::merge(TPUAliasSet &&Other) {
  auto &OtherWithDeps = static_cast<TPUAliasSetWithDep &>(Other);
  Deps |= OtherWithDeps.Deps;
  TPUAliasSet::merge(std::move(Other));
}
// Alias set tracker whose alias sets carry XLU dependency information.
using TPUAliasSetTrackerWithDep = TPUAliasSetTracker<TPUAliasSetWithDep>;
// Forward declaration; the graph class is defined further down in this file.
class XLUDepGraph;
// Node representing an XLU operation the algorithm can operate on.
// TODO(maggioni): In the first implementation I was tracking predecessors,
// but the current algorithm is not using them, so I removed them for the time
// being to make the Node class lighter.
class Node {
  // Kind of XLU operation the node stands for.
  enum class NodeType {
    Transpose,
    Rotate,
  };

public:
  using NodeIdx = unsigned;
  // The graph fills in the sequences, successors and the merge flag.
  friend class XLUDepGraph;

  // Creates an empty node of kind Ty identified by Id; nodes start out
  // mergeable.
  Node(NodeType Ty, unsigned Id) : Type(Ty), Id(Id) {}

  // Identifier of the node; successor lists store these identifiers.
  unsigned getId() const { return Id; }
  NodeType getNodeType() const { return Type; }
  // False once the graph has decided the node must not be merged.
  bool canMerge() const { return CanMerge; }
  // Push instructions of the sequence, in the order they were recorded.
  ArrayRef<Instruction *> getPushes() const { return PushSequence; }
  // Pop instructions of the sequence, in the order they were recorded.
  ArrayRef<Instruction *> getPops() const { return ReturnSequence; }
  // Identifiers of the nodes recorded as successors of this one.
  ArrayRef<NodeIdx> getSuccs() const { return Succs; }
  // Check if two nodes are compatible for merging.
  bool canMergeWith(const Node &N) const;

private:
  std::vector<Instruction *> PushSequence;
  std::vector<Instruction *> ReturnSequence;
  SmallVector<NodeIdx, 2> Succs;
  NodeType Type;
  unsigned Id;
  bool CanMerge = true;
};
bool Node::canMergeWith(const Node &N) const {
// Nodes need to be mergeable
if (!canMerge() || !N.canMerge())
return false;
// Nodes need to be of the same type.
if (getNodeType() != N.getNodeType())
return false;
switch (getNodeType()) {
case Node::NodeType::Rotate: {
// Rotate nodes need to have the same operands that are not the value to
// be rotated.
assert(getPushes().size() == 1 && N.getPushes().size() == 1);
if (TPU::getRotateAmount(*getPushes()[0]) !=
TPU::getRotateAmount(*N.getPushes()[0]))
return false;
if (TPU::getRotateBusIdx(*getPushes()[0]) !=
TPU::getRotateBusIdx(*N.getPushes()[0]))
return false;
break;
}
default:
break;
}
// If the two nodes have a dependency bail.
return std::find(Succs.begin(), Succs.end(), N.getId()) == Succs.end();
}
// Tracks dependencies of instructions with XLU nodes.
class DependenciesTracker {
  // For every tracked instruction, the set of XLU node indices it depends on.
  DenseMap<const Instruction *, DependenciesSet> DepsPerInstr;

public:
  // Returns the dependencies recorded for I, or nullptr when I has no entry.
  const DependenciesSet *getInstrDeps(const Instruction *I) const {
    auto Entry = DepsPerInstr.find(I);
    return Entry == DepsPerInstr.end() ? nullptr : &Entry->second;
  }
  // Adds DepSet to the dependencies of I, creating the entry if needed.
  void addDepsForInstr(const Instruction *I, const DependenciesSet &DepSet) {
    DepsPerInstr[I] |= DepSet;
  }
  // Computes and records the dependencies of I (defined out of line).
  void trackDeps(const Instruction *I, const XLUDepGraph &Graph,
                 TPUAliasSetTrackerWithDep &AliasTracker);
  // Forgets every recorded dependency.
  void clear() { DepsPerInstr.clear(); }
};
class XLUDepGraph {
using NodeVector = std::vector<Node>;
public:
using Iterator = NodeVector::iterator;
private:
NodeVector Nodes;
DenseMap<const Instruction *, Node::NodeIdx> InstrSequenceId;
DependenciesTracker DepTracker;
public:
Iterator begin() { return Nodes.begin(); }
Iterator end() { return Nodes.end(); }
void buildGraph(BasicBlock &BB);
void addPred(Node::NodeIdx Target, Node::NodeIdx Pred) {
Nodes[Pred].Succs.push_back(Target);
}
const DependenciesSet *getDepForInstr(const Instruction *I) const;
unsigned size() const { return Nodes.size(); }
Node &getNodeFromId(Node::NodeIdx Idx) {
assert(Idx < Nodes.size() && "Out of bounds");
return Nodes[Idx];
}
std::optional<Node::NodeIdx> getInstrSequenceId(const Instruction *I) const {
auto InstrIt = InstrSequenceId.find(I);
if (InstrIt == InstrSequenceId.end())
return std::nullopt;
return InstrIt->second;
}
};
// Add XLU dependencies for an instruction
void DependenciesTracker::trackDeps(const Instruction *I,
const XLUDepGraph &Graph,
TPUAliasSetTrackerWithDep &AliasTracker) {
DependenciesSet &NewDeps = DepsPerInstr[I];
for (auto &U : I->operands()) {
const Instruction *UI = dyn_cast<Instruction>(U.get());
if (!UI)
continue;
auto DepsIt = DepsPerInstr.find(UI);
if (DepsIt != DepsPerInstr.end()) {
for (auto NIdx : DepsIt->second)
NewDeps.set(NIdx);
}
}
if (!I->mayReadOrWriteMemory() || Graph.getInstrSequenceId(I).has_value()) {
LLVM_DEBUG(if (TPUXLUOptsPrintDeps) {
auto SeqId = Graph.getInstrSequenceId(I);
dbgs() << "Deps for: ";
if (SeqId.has_value())
dbgs() << "Seq " << *SeqId << " ";
dbgs() << " " << *I << " - ";
for (auto D : NewDeps) {
dbgs() << D << " ";
}
dbgs() << "\n";
});
return;
}
// Call back to be called when an alias set is first found to be aliased.
auto AddSetDepsToInstr = [&](TPUAliasSet *AS) {
TPUAliasSetWithDep *ASDep = static_cast<TPUAliasSetWithDep *>(AS);
NewDeps |= ASDep->getDeps();
};
// Callback to be called when the memory operation has been added to an
// alias set.
auto AddInstrDepsToSet = [&](TPUAliasSet *AS) {
TPUAliasSetWithDep *ASDep = static_cast<TPUAliasSetWithDep *>(AS);
ASDep->getDeps() = NewDeps;
};
if (NewDeps.empty()) {
// We want to track only operations that are dependent on XLU ops that
// we care about. So, if this instruction is not dependent yet check
// that is dependent on any memory operation we are tracking (if we are
// tracking it that means it is dependent on an XLU op) and if it is add
// it to the tracker.
if (AliasTracker.aliasQuery(I, /*AddToTracker*/ false) !=
AliasResult::NoAlias)
AliasTracker.aliasQuery(I, /*AddToTracker*/ true, AddSetDepsToInstr,
AddInstrDepsToSet);
} else {
AliasTracker.aliasQuery(I, /*AddToTracker*/ true, AddSetDepsToInstr,
AddInstrDepsToSet);
}
LLVM_DEBUG(if (TPUXLUOptsPrintDeps) {
auto SeqId = Graph.getInstrSequenceId(I);
dbgs() << "Deps for: ";
if (SeqId.has_value())
dbgs() << "Seq " << *SeqId;
dbgs() << " " << *I << " - ";
for (auto D : NewDeps) {
dbgs() << D << " ";
}
dbgs() << "\n";
});
}
// Forwards to the dependency tracker: returns the dependencies recorded for
// I, or nullptr when the tracker has no entry for it.
const DependenciesSet *XLUDepGraph::getDepForInstr(const Instruction *I) const {
  const DependenciesSet *InstrDeps = DepTracker.getInstrDeps(I);
  return InstrDeps;
}
// Builds the graph for BB in two passes.
//
// The forward pass creates one node per unpacked transpose sequence or
// unpacked rotate, assigns pushes and pops to their node, tracks the XLU
// dependencies of every instruction after the first sequence, and adds an edge
// from each node a push depends on to the node of that push.
//
// The backward pass decides which nodes can be merged: a node is marked as not
// mergeable when one of its pops has a use that might need full precision.
void XLUDepGraph::buildGraph(BasicBlock &BB) {
  // Alias tracker to check if memory operations are dependent with one another.
  TPUAliasSetTrackerWithDep AliasTracker(BB.getModule()->getDataLayout());
  // Reset all the per-block state, not only the dependency tracker, so that a
  // repeated call does not mix nodes and instruction ids of different runs.
  Nodes.clear();
  InstrSequenceId.clear();
  DepTracker.clear();
  bool FoundXLUSequence = false;
  // For each node whose sequence is still open, the dependencies for which an
  // edge has already been added.
  DenseMap<Node::NodeIdx, DependenciesSet> DepsAdded;
  for (auto &I : BB) {
    // Tracking only unpacked transposes as packed transposes are already
    // in the form we are trying to transform this to.
    const bool IsTransposePush = isTransposePushNotPacked(I);
    const bool IsRotatePush = isRotatePushNotPacked(I);
    // Skip instructions until we found a transpose.
    if (!FoundXLUSequence && !IsTransposePush && !IsRotatePush)
      continue;
    if (IsTransposePush || IsRotatePush) {
      const Instruction *PreviousPush =
          IsTransposePush ? TPU::getPreviousTransposePush(I, false) : nullptr;
      Node *SequenceNode;
      // Found a new transpose sequence. Create a new node.
      if (PreviousPush == nullptr) {
        FoundXLUSequence = true;
        Nodes.emplace_back(IsTransposePush ? Node::NodeType::Transpose
                                           : Node::NodeType::Rotate,
                           Nodes.size());
        SequenceNode = &Nodes.back();
      } else {
        assert(InstrSequenceId.count(PreviousPush) &&
               "Previous push instruction has no sequence id assigned");
        SequenceNode = &Nodes[InstrSequenceId[PreviousPush]];
      }
      InstrSequenceId[&I] = SequenceNode->getId();
      DepTracker.trackDeps(&I, *this, AliasTracker);
      SequenceNode->PushSequence.push_back(&I);
      // Different instructions of the transpose sequence could have different
      // dependencies, so we need to try to add edges for all of them.
      auto *Deps = DepTracker.getInstrDeps(&I);
      if (Deps != nullptr) {
        auto &CurrentDepsAdded = DepsAdded[SequenceNode->getId()];
        for (auto NIdx : *Deps) {
          // Add the edge only when the branch reports the dependency as not
          // yet recorded for this node.
          if (CurrentDepsAdded.test_and_set(NIdx)) {
            addPred(SequenceNode->getId(), NIdx);
          }
        }
      }
      // If this is the end of the sequence add all dependencies of pushes
      // to the pops and set the pops as part of the sequence.
      if (isTransposeEnd(I) || IsRotatePush) {
        auto DepIt = DepsAdded.find(SequenceNode->getId());
        if (DepIt != DepsAdded.end()) {
          // The pops also depend on the sequence itself.
          DepIt->second.set(SequenceNode->getId());
          for (auto *U : I.users()) {
            Instruction *UI = cast<Instruction>(U);
            assert(isXLUPop(*UI) &&
                   "Expected uses of transpose end to be pops");
            InstrSequenceId[UI] = SequenceNode->getId();
            DepTracker.addDepsForInstr(UI, DepIt->second);
          }
          assert(!I.user_empty());
          DepsAdded.erase(DepIt);
        }
      }
    } else if (isXLUPop(I)) {
      // This is an XLU pop, but it might be not part of a transpose sequence.
      // If it is add it to the return sequence vector and track it.
      auto SeqIt = InstrSequenceId.find(&I);
      if (SeqIt != InstrSequenceId.end())
        Nodes[SeqIt->second].ReturnSequence.push_back(&I);
      DepTracker.trackDeps(&I, *this, AliasTracker);
    } else {
      DepTracker.trackDeps(&I, *this, AliasTracker);
    }
  }
  // We need to check if the nodes are mergeable. A node is mergeable only
  // if its input can be truncated to BF16 without loss of precision.
  // TODO(maggioni): This is simpler than what LLO does right now.
  // We don't try to prove that ANDs or shifts clear away bits of precision
  // and we don't try to look through between loads and stores for now and
  // only consider stores (pessimizing the analysis).
  DenseMap<const Instruction *, bool> CanProduceBF16Map;
  for (auto It = BB.rbegin(), E = BB.rend(); It != E; ++It) {
    // If this is a transpose push then if it can produce BF16 depends on what
    // we determined about the uses when we checked the pops (because we are
    // iterating backwards we visit the pops before the pushes).
    if (InstrSequenceId.count(&*It) && !isXLUPop(*It)) {
      assert(TPU::isTransposePushNotPacked(*It) ||
             TPU::isRotatePushNotPacked(*It));
      CanProduceBF16Map[&*It] = Nodes[InstrSequenceId[&*It]].canMerge();
      continue;
    }
    CanProduceBF16Map[&*It] = true;
    // Check all the uses and find uses that potentially need full precision.
    for (auto &U : It->uses()) {
      const Instruction *UI = cast<Instruction>(U.getUser());
      if (reducesPrecisionToBF16(*UI, U.getOperandNo()))
        continue;
      auto PreserveOperands = preservesBF16OperandPrecision(*UI);
      // Not checking value that escape the block or might be used by PHIs (like
      // in a loop). Pessimize in that case.
      if (UI->getParent() != &BB || !CanProduceBF16Map.count(UI) ||
          !(CanProduceBF16Map[UI] &&
            std::find(PreserveOperands.begin(), PreserveOperands.end(),
                      U.getOperandNo()) != PreserveOperands.end())) {
        if (isXLUPop(*It)) {
          auto NodeIt = InstrSequenceId.find(&*It);
          // Transpose pops can be used to pop other TRF things, so check this
          // is actually part of a sequence.
          if (NodeIt != InstrSequenceId.end())
            Nodes[NodeIt->second].CanMerge = false;
        }
        // We might need full precision. Set CanProduceBF16 to false for the
        // instruction.
        CanProduceBF16Map[&*It] = false;
        break;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "Found " << Nodes.size() << " nodes\n"; int NNum = 0;
             for (auto &N
                  : Nodes) {
               dbgs() << "Node " << NNum << "\n";
               dbgs() << "\tSuccs: ";
               for (auto S : N.Succs) {
                 dbgs() << S << " ";
               }
               dbgs() << "CanMerge: " << N.canMerge() << "\n";
               ++NNum;
             });
}
// Class that performs the merge of XLU transpose ops.
class TPUXLUBF16Merger {
  // Dependency graph of the XLU sequences of BB; built by run().
  XLUDepGraph Graph;
  // Basic block being optimized.
  BasicBlock &BB;
  // Data layout of the module containing BB.
  const DataLayout &DL;
  // Merges the node at N2 into the node at N1 and returns the iterator the
  // caller should continue from.
  XLUDepGraph::Iterator mergeNodes(XLUDepGraph::Iterator N1,
                                   XLUDepGraph::Iterator N2);

public:
  // Explicit so that a BasicBlock is never silently converted into a merger.
  explicit TPUXLUBF16Merger(BasicBlock &BB)
      : BB(BB), DL(BB.getModule()->getDataLayout()) {}
  // Builds the graph for the block and merges compatible nodes. Returns true
  // if at least one pair of nodes was merged.
  bool run();
};
// Merges the XLU sequence of node N2 into the one of node N1.
//
// Step 1: walk the block backwards from the last pop of N2 to the first push
// of N1 and move, before N1, every instruction that the pushes of N2 depend on
// (through a use, or through memory according to the alias tracker). When such
// an instruction belongs to another XLU sequence the whole sequence is moved.
//
// Step 2: for each pair of pushes, push a tpu_pack of both operands through
// N1's push instead of the original operand. The pop of N1 is then split with
// a shift left by 16 (replacing the uses of N1's pop) and a mask with
// 0xFFFF0000 (replacing the uses of N2's pop). N2's pops and pushes are
// erased.
//
// Step 3: reorder the nodes in [N1It, N2It) so that the nodes that were moved
// before N1 come first.
//
// Returns the iterator to the first node of the reordered range that was not
// moved, or N1It when no node had to be moved.
//
// NOTE(review): the erased instructions of N2 are not removed from the maps
// the graph keys by instruction pointer; confirm this cannot cause stale
// lookups.
XLUDepGraph::Iterator TPUXLUBF16Merger::mergeNodes(XLUDepGraph::Iterator N1It,
                                                   XLUDepGraph::Iterator N2It) {
  // Bind by reference instead of copying the nodes: they are only read here
  // and the node vector is not modified until the final stable_partition,
  // after which N1 and N2 must not be used anymore.
  const Node &N1 = *N1It;
  const Node &N2 = *N2It;
  // Keep track of dependent nodes of N2 that we had to move before N1 to merge
  // N2 with N1.
  BitVector MovedNodes(Graph.size());
  unsigned MovedInstructions = 0;
  unsigned MovedNodesCount = 0;
  // Instructions that must end up before N1; seeded with the pushes of N2.
  SetVector<Instruction *> Deps(N2.getPushes().begin(), N2.getPushes().end());
  BasicBlock::reverse_iterator It = N2.getPops().back()->getReverseIterator();
  BasicBlock::reverse_iterator EndIt = N1.getPushes()[0]->getReverseIterator();
  // Moved instructions are inserted before this one; starts at the first push
  // of N1 and is updated to the most recently moved instruction.
  Instruction *InsertBefore = &*EndIt;
  // Tracks the memory operations that have been moved.
  TPUAliasSetTracker<> TAT(DL);
  // Helper to move instructions that moves XLU sequence in a block.
  auto MoveInstruction = [&](Instruction *I, std::optional<uint32_t> SeqId) {
    // If this is a sequence lets move the whole sequence
    if (SeqId.has_value()) {
      MovedNodes.set(SeqId.value());
      // Skip the instructions of the same sequence that directly follow in
      // the backward walk: they are moved together with the sequence below.
      while (It != EndIt) {
        auto ItSeq = Graph.getInstrSequenceId(&*It);
        if (!ItSeq.has_value())
          break;
        if (ItSeq.value() != SeqId.value())
          break;
        ++It;
      }
      auto &SeqNode = Graph.getNodeFromId(SeqId.value());
      for (auto *P : SeqNode.getPushes()) {
        P->moveBefore(InsertBefore);
        Deps.insert(P);
      }
      for (auto *P : SeqNode.getPops()) {
        P->moveBefore(InsertBefore);
        Deps.insert(P);
      }
      MovedInstructions += SeqNode.getPushes().size();
      MovedInstructions += SeqNode.getPops().size();
      ++MovedNodesCount;
      InsertBefore = SeqNode.getPushes()[0];
      return;
    }
    I->moveBefore(InsertBefore);
    InsertBefore = I;
    Deps.insert(I);
    ++MovedInstructions;
  };
  unsigned N1Id = N1.getId();
  unsigned N2Id = N2.getId();
  // Move instruction for merging.
  while (It != EndIt) {
    // Advance before the instruction is potentially moved.
    Instruction *I = (&*It++);
    auto SeqId = Graph.getInstrSequenceId(I);
    // Do not move instructions of the two nodes we are merging.
    if (SeqId.has_value() && (SeqId.value() == N1Id || SeqId.value() == N2Id))
      continue;
    bool AddedToDeps = false;
    // If any of the users of the current instruction is one of the instructions
    // we determined we have to move then move this instruction as well.
    for (auto *U : I->users()) {
      Instruction *UI = cast<Instruction>(U);
      if (Deps.count(UI)) {
        AddedToDeps = true;
        assert(
            (!Graph.getDepForInstr(I) ||
             !Graph.getDepForInstr(I)->test(N1.getId())) &&
            "Adding dep that make us dependent on node we want to merge with");
        MoveInstruction(I, SeqId);
        break;
      }
    }
    // Check for dependencies with previously moved instructions through
    // memory.
    if (!I->mayReadOrWriteMemory())
      continue;
    // We don't consider handled XLU operations as aliasing with other
    // inaccessible mem intrinsics.
    if (SeqId.has_value())
      continue;
    // If this instruction doesn't alias then we are good and we don't need
    // to move it otherwise move it and track it. If the instruction is already
    // in Deps then we don't need to move it. Just track it.
    if (!AddedToDeps && TAT.aliasQuery(I, false) != AliasResult::NoAlias) {
      assert((!Graph.getDepForInstr(I) ||
              !Graph.getDepForInstr(I)->test(N1.getId())) &&
             "Adding dep that make us dependent on node we want to merge with");
      MoveInstruction(I, SeqId);
    }
    if (Deps.count(I))
      TAT.aliasQuery(I, true);
  }
  LLVM_DEBUG(dbgs() << "Moved instructions: " << MovedInstructions << "\n");
  assert(N1.getPushes().size() == N2.getPushes().size());
  // Merge the two XLU transpose sequences.
  IRBuilder<> Builder(BB.getContext());
  // Bitcasts a vector value to a float vector with the same element count,
  // unless it already is a floating point vector.
  auto ForceToFloat = [&Builder](Value *V) {
    assert(V->getType()->isVectorTy());
    Value *Float = V;
    if (!Float->getType()->isFPOrFPVectorTy()) {
      Float = Builder.CreateBitCast(
          Float, VectorType::get(
                     Builder.getFloatTy(),
                     cast<VectorType>(Float->getType())->getElementCount()));
    }
    return Float;
  };
  // Bitcasts V to the type of Of when the two types differ.
  auto ForceToTypeOf = [&Builder](Value *V, Value *Of) {
    assert(V->getType()->isVectorTy());
    assert(Of->getType()->isVectorTy());
    Value *Result = V;
    if (Result->getType() != Of->getType()) {
      Result = Builder.CreateBitCast(Result, Of->getType());
    }
    return Result;
  };
  Type *FloatVecTy = VectorType::get(Builder.getFloatTy(), 1024, false);
  Function *PackIntr = llvm::Intrinsic::getDeclaration(
      BB.getModule(), llvm::Intrinsic::tpu_pack, {FloatVecTy});
  for (int I = N1.getPushes().size() - 1; I >= 0; --I) {
    Instruction *Push1 = N1.getPushes()[I];
    Instruction *Push2 = N2.getPushes()[I];
    Builder.SetInsertPoint(N1.getPushes()[I]);
    // Pack the operand of N2's push together with the operand of N1's push
    // and push the packed value through N1's push.
    Instruction *Pack =
        Builder.CreateCall(PackIntr, {ForceToFloat(Push2->getOperand(0)),
                                      ForceToFloat(Push1->getOperand(0))});
    Push1->replaceUsesOfWith(Push1->getOperand(0),
                             ForceToTypeOf(Pack, Push1->getOperand(0)));
    Instruction *Pop1 = N1.getPops()[I];
    Instruction *Pop2 = N2.getPops()[I];
    Builder.SetInsertPoint(&BB, std::next(Pop1->getIterator()));
    // Split the popped value: the shift feeds the former users of Pop1, the
    // mask feeds the former users of Pop2.
    Instruction *UnpackL = cast<Instruction>(Builder.CreateShl(
        Pop1, ConstantDataVector::getSplat(1024, Builder.getInt32(16))));
    Instruction *UnpackU = cast<Instruction>(Builder.CreateAnd(
        Pop1,
        ConstantDataVector::getSplat(1024, Builder.getInt32(0xFFFF0000U))));
    // The unpack instructions themselves must keep reading Pop1.
    Pop1->replaceUsesWithIf(UnpackL, [UnpackL, UnpackU](Use &U) {
      return U.getUser() != UnpackL && U.getUser() != UnpackU;
    });
    Pop2->replaceUsesWithIf(
        UnpackU, [UnpackU](Use &U) { return U.getUser() != UnpackU; });
    Pop2->eraseFromParent();
  }
  // The pushes of N2 are erased last to first.
  for (int I = N2.getPushes().size() - 1; I >= 0; --I) {
    N2.getPushes()[I]->eraseFromParent();
  }
  auto NewIt = N1It;
  if (MovedNodesCount > 0) {
    // Put the nodes that were moved before N1 at the front of the range.
    // N1 and N2 must not be used past this point.
    NewIt = std::stable_partition(
        N1It, N2It, [&](const Node &N) { return MovedNodes.test(N.getId()); });
  }
  LLVM_DEBUG(dbgs() << "Post reorder: "; for (auto &N
                                              : Graph) {
    dbgs() << N.getId() << " ";
  } dbgs() << "\n";);
  return NewIt;
}
bool TPUXLUBF16Merger::run() {
Graph.buildGraph(BB);
if (Graph.begin() == Graph.end())
return false;
auto It = Graph.begin();
auto E = std::prev(Graph.end());
BitVector Merged(Graph.size());
unsigned MergedNodes = 0;
bool Changed = false;
// We currently naively iterate the graph in node order and try to merge
// eagerly with the first node we find
// TODO(maggioni): Investigate fancier algorithms.
while (It != E) {
auto &N = *It;
// If the node cannot be merged (preserve precision) or has been merged
// already continue.
if (!N.canMerge() || Merged.test(N.getId())) {
++It;
continue;
}
// Helper to evaluate if a node can be merged with the current node.
// Check if the node ToMerge can be merged, if it is of the same type of
// the current node, if it hasn't been merged already and if it is not one
// of the successors of the current node.
auto ValidForMerge = [&](const Node &ToMerge) {
return !Merged.test(ToMerge.getId()) && N.canMergeWith(ToMerge);
};
auto ToMergeIt = std::find_if(std::next(It), Graph.end(), ValidForMerge);
if (ToMergeIt == Graph.end()) {
++It;
// Didn't find anything to merge with
continue;
}
// Perform the merge
++MergedNodes;
LLVM_DEBUG(dbgs() << "Merging Node " << N.getId() << " with node "
<< ToMergeIt->getId() << "\n");
Merged.set(N.getId());
Merged.set(ToMergeIt->getId());
It = mergeNodes(It, ToMergeIt);
Changed = true;
}
(void)MergedNodes;
LLVM_DEBUG(dbgs() << "Num of merged nodes: " << MergedNodes << "\n");
return Changed;
}
// Function pass that runs the XLU optimizations on every basic block.
class TPUXLUOptimizations : public FunctionPass {
  // Optimizes one basic block; returns true if the block was changed.
  bool processBasicBlock(BasicBlock &BB);

public:
  static char ID;

  TPUXLUOptimizations() : FunctionPass(ID) {}

  StringRef getPassName() const override { return "TPU XLU optimizations"; }

  // The pass declares that it preserves the CFG.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  // Runs the optimizations on F; returns true if F was changed.
  bool runOnFunction(Function &F) override;
};
// Definition of the static pass identifier declared in TPUXLUOptimizations.
char TPUXLUOptimizations::ID = 0;
} // namespace
// Pass initialization boilerplate; DEBUG_TYPE ("tpu-xlu-opt") is passed as the
// pass argument and "TPU XLU optimizations" as its description.
INITIALIZE_PASS(TPUXLUOptimizations, DEBUG_TYPE, "TPU XLU optimizations", false,
false)
// Factory for the pass: returns a newly allocated TPUXLUOptimizations.
// NOTE(review): ownership of the returned pass is presumably taken by the
// pass manager it is added to; confirm against the caller.
Pass *llvm::createTPUXLUOptimizationsPass() {
return new TPUXLUOptimizations();
}
// Runs the BF16 merger on a single basic block and reports whether it changed
// the block.
bool TPUXLUOptimizations::processBasicBlock(BasicBlock &BB) {
  TPUXLUBF16Merger Merger(BB);
  return Merger.run();
}
// Processes every basic block of F and returns true if any of them was
// changed. In builds with assertions the resulting IR is verified.
bool TPUXLUOptimizations::runOnFunction(Function &F) {
  bool AnyBlockChanged = false;
  for (BasicBlock &Block : F) {
    // Every block is processed, even after an earlier one reported a change.
    if (processBasicBlock(Block))
      AnyBlockChanged = true;
  }
  // After pass verification of valid IR.
  assert(!verifyFunction(F, &dbgs()));
  return AnyBlockChanged;
}