| //===-- TPUXLUOptimizations.cpp - Optimizations over XLU ops ---*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // Optimizations over XLU operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "TPU.h" |
| #include "TPUAliasSetTracker.h" |
| #include "TPUIRUtils.h" |
| #include "llvm/ADT/Optional.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/SparseBitVector.h" |
| #include "llvm/Analysis/AliasAnalysis.h" |
| #include "llvm/Analysis/MemoryLocation.h" |
| #include "llvm/Analysis/ValueTracking.h" |
| #include "llvm/CodeGen/MachineInstr.h" |
| #include "llvm/CodeGen/TargetPassConfig.h" |
| #include "llvm/IR/IRBuilder.h" |
| #include "llvm/IR/Instructions.h" |
| #include "llvm/IR/IntrinsicInst.h" |
| #include "llvm/IR/Intrinsics.h" |
| #include "llvm/IR/IntrinsicsTPU.h" |
| #include "llvm/IR/Metadata.h" |
| #include "llvm/IR/Operator.h" |
| #include "llvm/IR/Verifier.h" |
| |
| #include <algorithm> |
| |
| #define DEBUG_TYPE "tpu-xlu-opt" |
| |
| using namespace llvm; |
| using namespace llvm::TPU; |
| |
| namespace { |
| |
| cl::opt<bool> TPUXLUOptsPrintDeps( |
| "tpu-xlu-opts-print-deps", cl::init(false), |
| cl::desc( |
| "When debug is enabled print XLU dependencies between instructions.")); |
| |
| using DependenciesSet = SparseBitVector<>; |
| |
| // Custom alias set that tracks dependencies of XLU operations per alias set. |
| class TPUAliasSetWithDep : public TPUAliasSet { |
| DependenciesSet Deps; |
| |
| public: |
| DependenciesSet &getDeps() { return Deps; }; |
| void merge(TPUAliasSet &&Other) override; |
| }; |
| |
| void TPUAliasSetWithDep::merge(TPUAliasSet &&Other) { |
| Deps |= static_cast<TPUAliasSetWithDep &&>(Other).Deps; |
| TPUAliasSet::merge(std::move(Other)); |
| } |
| |
| using TPUAliasSetTrackerWithDep = TPUAliasSetTracker<TPUAliasSetWithDep>; |
| |
| class XLUDepGraph; |
| |
| // Node reprenting an XLU operation the algorithm can operate on. |
| // TODO(maggioni): In the first implementation I was tracking predecessors, |
| // but the current algorithm is not using them, so I removed them for the time |
| // being to make the Node class lighter. |
| class Node { |
| enum class NodeType { |
| Transpose, |
| Rotate, |
| }; |
| |
| public: |
| using NodeIdx = unsigned; |
| friend class XLUDepGraph; |
| ArrayRef<Instruction *> getPushes() const { return PushSequence; } |
| ArrayRef<Instruction *> getPops() const { return ReturnSequence; } |
| ArrayRef<NodeIdx> getSuccs() const { return Succs; } |
| unsigned getId() const { return Id; } |
| bool canMerge() const { return CanMerge; } |
| // Check if two nodes are compatible for merging. |
| bool canMergeWith(const Node &N) const; |
| NodeType getNodeType() const { return Type; } |
| |
| Node(NodeType Ty, unsigned Id) : Type(Ty), Id(Id), CanMerge(true) {} |
| |
| private: |
| std::vector<Instruction *> PushSequence; |
| std::vector<Instruction *> ReturnSequence; |
| SmallVector<NodeIdx, 2> Succs; |
| NodeType Type; |
| unsigned Id; |
| bool CanMerge; |
| }; |
| |
| bool Node::canMergeWith(const Node &N) const { |
| // Nodes need to be mergeable |
| if (!canMerge() || !N.canMerge()) |
| return false; |
| // Nodes need to be of the same type. |
| if (getNodeType() != N.getNodeType()) |
| return false; |
| switch (getNodeType()) { |
| case Node::NodeType::Rotate: { |
| // Rotate nodes need to have the same operands that are not the value to |
| // be rotated. |
| assert(getPushes().size() == 1 && N.getPushes().size() == 1); |
| if (TPU::getRotateAmount(*getPushes()[0]) != |
| TPU::getRotateAmount(*N.getPushes()[0])) |
| return false; |
| if (TPU::getRotateBusIdx(*getPushes()[0]) != |
| TPU::getRotateBusIdx(*N.getPushes()[0])) |
| return false; |
| break; |
| } |
| default: |
| break; |
| } |
| // If the two nodes have a dependency bail. |
| return std::find(Succs.begin(), Succs.end(), N.getId()) == Succs.end(); |
| } |
| |
| // Tracks dependencies of instructions with XLU nodes. |
| class DependenciesTracker { |
| // Map tracking Dependencies of XLU nodes for every instruction. |
| DenseMap<const Instruction *, DependenciesSet> DepsPerInstr; |
| |
| public: |
| const DependenciesSet *getInstrDeps(const Instruction *I) const { |
| auto It = DepsPerInstr.find(I); |
| if (It != DepsPerInstr.end()) |
| return &It->second; |
| return nullptr; |
| } |
| void addDepsForInstr(const Instruction *I, const DependenciesSet &DepSet) { |
| auto &DepsI = DepsPerInstr[I]; |
| DepsI |= DepSet; |
| } |
| void trackDeps(const Instruction *I, const XLUDepGraph &Graph, |
| TPUAliasSetTrackerWithDep &AliasTracker); |
| void clear() { DepsPerInstr.clear(); } |
| }; |
| |
| class XLUDepGraph { |
| using NodeVector = std::vector<Node>; |
| |
| public: |
| using Iterator = NodeVector::iterator; |
| |
| private: |
| NodeVector Nodes; |
| DenseMap<const Instruction *, Node::NodeIdx> InstrSequenceId; |
| DependenciesTracker DepTracker; |
| |
| public: |
| Iterator begin() { return Nodes.begin(); } |
| Iterator end() { return Nodes.end(); } |
| void buildGraph(BasicBlock &BB); |
| void addPred(Node::NodeIdx Target, Node::NodeIdx Pred) { |
| Nodes[Pred].Succs.push_back(Target); |
| } |
| const DependenciesSet *getDepForInstr(const Instruction *I) const; |
| unsigned size() const { return Nodes.size(); } |
| Node &getNodeFromId(Node::NodeIdx Idx) { |
| assert(Idx < Nodes.size() && "Out of bounds"); |
| return Nodes[Idx]; |
| } |
| std::optional<Node::NodeIdx> getInstrSequenceId(const Instruction *I) const { |
| auto InstrIt = InstrSequenceId.find(I); |
| if (InstrIt == InstrSequenceId.end()) |
| return std::nullopt; |
| return InstrIt->second; |
| } |
| }; |
| |
| // Add XLU dependencies for an instruction |
| void DependenciesTracker::trackDeps(const Instruction *I, |
| const XLUDepGraph &Graph, |
| TPUAliasSetTrackerWithDep &AliasTracker) { |
| DependenciesSet &NewDeps = DepsPerInstr[I]; |
| for (auto &U : I->operands()) { |
| const Instruction *UI = dyn_cast<Instruction>(U.get()); |
| if (!UI) |
| continue; |
| auto DepsIt = DepsPerInstr.find(UI); |
| if (DepsIt != DepsPerInstr.end()) { |
| for (auto NIdx : DepsIt->second) |
| NewDeps.set(NIdx); |
| } |
| } |
| |
| if (!I->mayReadOrWriteMemory() || Graph.getInstrSequenceId(I).has_value()) { |
| LLVM_DEBUG(if (TPUXLUOptsPrintDeps) { |
| auto SeqId = Graph.getInstrSequenceId(I); |
| dbgs() << "Deps for: "; |
| if (SeqId.has_value()) |
| dbgs() << "Seq " << *SeqId << " "; |
| dbgs() << " " << *I << " - "; |
| for (auto D : NewDeps) { |
| dbgs() << D << " "; |
| } |
| dbgs() << "\n"; |
| }); |
| return; |
| } |
| |
| // Call back to be called when an alias set is first found to be aliased. |
| auto AddSetDepsToInstr = [&](TPUAliasSet *AS) { |
| TPUAliasSetWithDep *ASDep = static_cast<TPUAliasSetWithDep *>(AS); |
| NewDeps |= ASDep->getDeps(); |
| }; |
| // Callback to be called when the memory operation has been added to an |
| // alias set. |
| auto AddInstrDepsToSet = [&](TPUAliasSet *AS) { |
| TPUAliasSetWithDep *ASDep = static_cast<TPUAliasSetWithDep *>(AS); |
| ASDep->getDeps() = NewDeps; |
| }; |
| if (NewDeps.empty()) { |
| // We want to track only operations that are dependent on XLU ops that |
| // we care about. So, if this instruction is not dependent yet check |
| // that is dependent on any memory operation we are tracking (if we are |
| // tracking it that means it is dependent on an XLU op) and if it is add |
| // it to the tracker. |
| if (AliasTracker.aliasQuery(I, /*AddToTracker*/ false) != |
| AliasResult::NoAlias) |
| AliasTracker.aliasQuery(I, /*AddToTracker*/ true, AddSetDepsToInstr, |
| AddInstrDepsToSet); |
| } else { |
| AliasTracker.aliasQuery(I, /*AddToTracker*/ true, AddSetDepsToInstr, |
| AddInstrDepsToSet); |
| } |
| LLVM_DEBUG(if (TPUXLUOptsPrintDeps) { |
| auto SeqId = Graph.getInstrSequenceId(I); |
| dbgs() << "Deps for: "; |
| if (SeqId.has_value()) |
| dbgs() << "Seq " << *SeqId; |
| dbgs() << " " << *I << " - "; |
| for (auto D : NewDeps) { |
| dbgs() << D << " "; |
| } |
| dbgs() << "\n"; |
| }); |
| } |
| |
| const DependenciesSet *XLUDepGraph::getDepForInstr(const Instruction *I) const { |
| return DepTracker.getInstrDeps(I); |
| } |
| |
| void XLUDepGraph::buildGraph(BasicBlock &BB) { |
| // Alias tracker to check if memory operations are dependent with one another. |
| TPUAliasSetTrackerWithDep AliasTracker(BB.getModule()->getDataLayout()); |
| DepTracker.clear(); |
| bool FoundXLUSequence = false; |
| DenseMap<Node::NodeIdx, DependenciesSet> DepsAdded; |
| for (auto &I : BB) { |
| // Tracking only unpacked transposes as packed transposes are already |
| // in the form we are trying to transform this to. |
| const bool IsTransposePush = isTransposePushNotPacked(I); |
| const bool IsRotatePush = isRotatePushNotPacked(I); |
| // Skip instructions until we found a transpose. |
| if (!FoundXLUSequence && !IsTransposePush && !IsRotatePush) |
| continue; |
| if (IsTransposePush || IsRotatePush) { |
| const Instruction *PreviousPush = |
| IsTransposePush ? TPU::getPreviousTransposePush(I, false) : nullptr; |
| Node *SequenceNode; |
| // Found a new transpose sequence. Create a new node. |
| if (PreviousPush == nullptr) { |
| FoundXLUSequence = true; |
| Nodes.emplace_back(IsTransposePush ? Node::NodeType::Transpose |
| : Node::NodeType::Rotate, |
| Nodes.size()); |
| SequenceNode = &Nodes.back(); |
| } else { |
| assert(InstrSequenceId.count(PreviousPush) && |
| "Previous push instruction has no sequence id assigned"); |
| SequenceNode = &Nodes[InstrSequenceId[PreviousPush]]; |
| } |
| InstrSequenceId[&I] = SequenceNode->getId(); |
| DepTracker.trackDeps(&I, *this, AliasTracker); |
| SequenceNode->PushSequence.push_back(&I); |
| // Different instructions of the transpose sequence could have different |
| // dependencies, so we need to try to add edges for all of them. |
| auto *Deps = DepTracker.getInstrDeps(&I); |
| if (Deps != nullptr) { |
| auto &CurrentDepsAdded = DepsAdded[SequenceNode->getId()]; |
| for (auto NIdx : *Deps) { |
| if (CurrentDepsAdded.test_and_set(NIdx)) { |
| addPred(SequenceNode->getId(), NIdx); |
| } |
| } |
| } |
| // If this is the end of the sequence add all dependencies of pushes |
| // to the pops and set the pops as part of the sequence. |
| if (isTransposeEnd(I) || IsRotatePush) { |
| auto DepIt = DepsAdded.find(SequenceNode->getId()); |
| if (DepIt != DepsAdded.end()) { |
| DepIt->second.set(SequenceNode->getId()); |
| for (auto *U : I.users()) { |
| Instruction *UI = cast<Instruction>(U); |
| assert(isXLUPop(*UI) && |
| "Expected uses of transpose end to be pops"); |
| InstrSequenceId[UI] = SequenceNode->getId(); |
| DepTracker.addDepsForInstr(UI, DepIt->second); |
| } |
| assert(!I.user_empty()); |
| DepsAdded.erase(DepIt); |
| } |
| } |
| } else if (isXLUPop(I)) { |
| // This is an XLU pop, but it might be not part of a transpose sequence. |
| // If it is add it to the return sequence vector and track it. |
| auto SeqIt = InstrSequenceId.find(&I); |
| if (SeqIt != InstrSequenceId.end()) |
| Nodes[SeqIt->second].ReturnSequence.push_back(&I); |
| DepTracker.trackDeps(&I, *this, AliasTracker); |
| } else { |
| DepTracker.trackDeps(&I, *this, AliasTracker); |
| } |
| } |
| // We need to check if the nodes are mergeable. A node is mergeable only |
| // if its input can be truncated to BF16 without loss of precision. |
| // TODO(maggioni): This is simpler than what LLO does right now. |
| // We don't try to prove that ANDs or shifts clear away bits of precision |
| // and we don't try to look through between loads and stores for now and |
| // only consider stores (pessimizing the analysis. |
| DenseMap<const Instruction *, bool> CanProduceBF16Map; |
| for (auto It = BB.rbegin(), E = BB.rend(); It != E; ++It) { |
| // If this is a transpose push then if it can produce BF16 depends on what |
| // we determined about the uses when we checked the pops (because we are |
| // iterating backwards we visit the pops before the pushes). |
| if (InstrSequenceId.count(&*It) && !isXLUPop(*It)) { |
| assert(TPU::isTransposePushNotPacked(*It) || |
| TPU::isRotatePushNotPacked(*It)); |
| CanProduceBF16Map[&*It] = Nodes[InstrSequenceId[&*It]].canMerge(); |
| continue; |
| } |
| CanProduceBF16Map[&*It] = true; |
| // Check all the uses and find uses that potentially need full precision. |
| for (auto &U : It->uses()) { |
| const Instruction *UI = cast<Instruction>(U.getUser()); |
| if (reducesPrecisionToBF16(*UI, U.getOperandNo())) |
| continue; |
| auto PreserveOperands = preservesBF16OperandPrecision(*UI); |
| |
| // Not checking value that escape the block or might be used by PHIs (like |
| // in a loop). Pessimize in that case. |
| if (UI->getParent() != &BB || !CanProduceBF16Map.count(UI) || |
| !(CanProduceBF16Map[UI] && |
| std::find(PreserveOperands.begin(), PreserveOperands.end(), |
| U.getOperandNo()) != PreserveOperands.end())) { |
| if (isXLUPop(*It)) { |
| auto NodeIt = InstrSequenceId.find(&*It); |
| // Transpose pops can be used to pop other TRF things, so check this |
| // is actually part of a sequence. |
| if (NodeIt != InstrSequenceId.end()) |
| Nodes[NodeIt->second].CanMerge = false; |
| } |
| // We might need full precision. Set CanProduceBF16 to false for the |
| // instruction. |
| CanProduceBF16Map[&*It] = false; |
| break; |
| } |
| } |
| } |
| LLVM_DEBUG(dbgs() << "Found " << Nodes.size() << " nodes\n"; int NNum = 0; |
| for (auto &N |
| : Nodes) { |
| dbgs() << "Node " << NNum << "\n"; |
| dbgs() << "\tSuccs: "; |
| for (auto S : N.Succs) { |
| dbgs() << S << " "; |
| } |
| dbgs() << "CanMerge: " << N.canMerge() << "\n"; |
| ++NNum; |
| }); |
| } |
| |
| // Class that performs the merge of XLU transpose ops. |
| class TPUXLUBF16Merger { |
| XLUDepGraph Graph; |
| BasicBlock &BB; |
| const DataLayout &DL; |
| |
| XLUDepGraph::Iterator mergeNodes(XLUDepGraph::Iterator N1, |
| XLUDepGraph::Iterator N2); |
| |
| public: |
| TPUXLUBF16Merger(BasicBlock &BB) |
| : BB(BB), DL(BB.getModule()->getDataLayout()) {} |
| |
| bool run(); |
| }; |
| |
| XLUDepGraph::Iterator TPUXLUBF16Merger::mergeNodes(XLUDepGraph::Iterator N1It, |
| XLUDepGraph::Iterator N2It) { |
| auto N1 = *N1It; |
| auto N2 = *N2It; |
| // Keep track of dependent nodes of N2 that we had to move before N1 to merge |
| // N2 with N1. |
| BitVector MovedNodes(Graph.size()); |
| unsigned MovedInstructions = 0; |
| unsigned MovedNodesCount = 0; |
| SetVector<Instruction *> Deps(N2.getPushes().begin(), N2.getPushes().end()); |
| BasicBlock::reverse_iterator It = N2.getPops().back()->getReverseIterator(); |
| BasicBlock::reverse_iterator EndIt = N1.getPushes()[0]->getReverseIterator(); |
| Instruction *InsertBefore = &*EndIt; |
| TPUAliasSetTracker<> TAT(DL); |
| // Helper to move instructions that moves XLU sequence in a block. |
| auto MoveInstruction = [&](Instruction *I, std::optional<uint32_t> SeqId) { |
| // If this is a sequence lets move the whole sequence |
| if (SeqId.has_value()) { |
| MovedNodes.set(SeqId.value()); |
| while (It != EndIt) { |
| auto ItSeq = Graph.getInstrSequenceId(&*It); |
| if (!ItSeq.has_value()) |
| break; |
| if (ItSeq.value() != SeqId.value()) |
| break; |
| ++It; |
| } |
| auto &SeqNode = Graph.getNodeFromId(SeqId.value()); |
| for (auto *P : SeqNode.getPushes()) { |
| P->moveBefore(InsertBefore); |
| Deps.insert(P); |
| } |
| for (auto *P : SeqNode.getPops()) { |
| P->moveBefore(InsertBefore); |
| Deps.insert(P); |
| } |
| MovedInstructions += SeqNode.getPushes().size(); |
| MovedInstructions += SeqNode.getPops().size(); |
| ++MovedNodesCount; |
| InsertBefore = SeqNode.getPushes()[0]; |
| return; |
| } |
| I->moveBefore(InsertBefore); |
| InsertBefore = I; |
| Deps.insert(I); |
| ++MovedInstructions; |
| }; |
| unsigned N1Id = N1.getId(); |
| unsigned N2Id = N2.getId(); |
| // Move instruction for merging. |
| while (It != EndIt) { |
| Instruction *I = (&*It++); |
| auto SeqId = Graph.getInstrSequenceId(I); |
| // Do not move instructions of the two nodes we are merging. |
| if (SeqId.has_value() && (SeqId.value() == N1Id || SeqId.value() == N2Id)) |
| continue; |
| bool AddedToDeps = false; |
| // If any of the users of the current instruction is one of the instructions |
| // we determined we have to move then move this instruction as well. |
| for (auto *U : I->users()) { |
| Instruction *UI = cast<Instruction>(U); |
| if (Deps.count(UI)) { |
| AddedToDeps = true; |
| assert( |
| (!Graph.getDepForInstr(I) || |
| !Graph.getDepForInstr(I)->test(N1.getId())) && |
| "Adding dep that make us dependent on node we want to merge with"); |
| MoveInstruction(I, SeqId); |
| break; |
| } |
| } |
| // Check for dependencies with previously moved instructions through |
| // memory. |
| if (!I->mayReadOrWriteMemory()) |
| continue; |
| // We don't consider handled XLU operations as aliasing with other |
| // inaccessible mem intrinsics. |
| if (SeqId.has_value()) |
| continue; |
| // If this instruction doesn't alias then we are good and we don't need |
| // to move it otherwise move it and track it. If the instruction is already |
| // in Deps then we don't need to move it. Just track it. |
| if (!AddedToDeps && TAT.aliasQuery(I, false) != AliasResult::NoAlias) { |
| assert((!Graph.getDepForInstr(I) || |
| !Graph.getDepForInstr(I)->test(N1.getId())) && |
| "Adding dep that make us dependent on node we want to merge with"); |
| MoveInstruction(I, SeqId); |
| } |
| if (Deps.count(I)) |
| TAT.aliasQuery(I, true); |
| } |
| LLVM_DEBUG(dbgs() << "Moved instructions: " << MovedInstructions << "\n"); |
| assert(N1.getPushes().size() == N2.getPushes().size()); |
| // Merge the two XLU transpose sequences. |
| IRBuilder<> Builder(BB.getContext()); |
| auto ForceToFloat = [&Builder](Value *V) { |
| assert(V->getType()->isVectorTy()); |
| Value *Float = V; |
| if (!Float->getType()->isFPOrFPVectorTy()) { |
| Float = Builder.CreateBitCast( |
| Float, VectorType::get( |
| Builder.getFloatTy(), |
| cast<VectorType>(Float->getType())->getElementCount())); |
| } |
| return Float; |
| }; |
| auto ForceToTypeOf = [&Builder](Value *V, Value *Of) { |
| assert(V->getType()->isVectorTy()); |
| assert(Of->getType()->isVectorTy()); |
| Value *Result = V; |
| if (Result->getType() != Of->getType()) { |
| Result = Builder.CreateBitCast(Result, Of->getType()); |
| } |
| return Result; |
| }; |
| Type *FloatVecTy = VectorType::get(Builder.getFloatTy(), 1024, false); |
| Function *PackIntr = llvm::Intrinsic::getDeclaration( |
| BB.getModule(), llvm::Intrinsic::tpu_pack, {FloatVecTy}); |
| for (int I = N1.getPushes().size() - 1; I >= 0; --I) { |
| Instruction *Push1 = N1.getPushes()[I]; |
| Instruction *Push2 = N2.getPushes()[I]; |
| Builder.SetInsertPoint(N1.getPushes()[I]); |
| Instruction *Pack = |
| Builder.CreateCall(PackIntr, {ForceToFloat(Push2->getOperand(0)), |
| ForceToFloat(Push1->getOperand(0))}); |
| Push1->replaceUsesOfWith(Push1->getOperand(0), |
| ForceToTypeOf(Pack, Push1->getOperand(0))); |
| Instruction *Pop1 = N1.getPops()[I]; |
| Instruction *Pop2 = N2.getPops()[I]; |
| Builder.SetInsertPoint(&BB, std::next(Pop1->getIterator())); |
| Instruction *UnpackL = cast<Instruction>(Builder.CreateShl( |
| Pop1, ConstantDataVector::getSplat(1024, Builder.getInt32(16)))); |
| Instruction *UnpackU = cast<Instruction>(Builder.CreateAnd( |
| Pop1, |
| ConstantDataVector::getSplat(1024, Builder.getInt32(0xFFFF0000U)))); |
| Pop1->replaceUsesWithIf(UnpackL, [UnpackL, UnpackU](Use &U) { |
| return U.getUser() != UnpackL && U.getUser() != UnpackU; |
| }); |
| Pop2->replaceUsesWithIf( |
| UnpackU, [UnpackU](Use &U) { return U.getUser() != UnpackU; }); |
| Pop2->eraseFromParent(); |
| } |
| for (int I = N2.getPushes().size() - 1; I >= 0; --I) { |
| N2.getPushes()[I]->eraseFromParent(); |
| } |
| auto NewIt = N1It; |
| if (MovedNodesCount > 0) { |
| NewIt = std::stable_partition( |
| N1It, N2It, [&](const Node &N) { return MovedNodes.test(N.getId()); }); |
| } |
| LLVM_DEBUG(dbgs() << "Post reorder: "; for (auto &N |
| : Graph) { |
| dbgs() << N.getId() << " "; |
| } dbgs() << "\n";); |
| return NewIt; |
| } |
| |
| bool TPUXLUBF16Merger::run() { |
| Graph.buildGraph(BB); |
| if (Graph.begin() == Graph.end()) |
| return false; |
| auto It = Graph.begin(); |
| auto E = std::prev(Graph.end()); |
| BitVector Merged(Graph.size()); |
| unsigned MergedNodes = 0; |
| bool Changed = false; |
| // We currently naively iterate the graph in node order and try to merge |
| // eagerly with the first node we find |
| // TODO(maggioni): Investigate fancier algorithms. |
| while (It != E) { |
| auto &N = *It; |
| // If the node cannot be merged (preserve precision) or has been merged |
| // already continue. |
| if (!N.canMerge() || Merged.test(N.getId())) { |
| ++It; |
| continue; |
| } |
| // Helper to evaluate if a node can be merged with the current node. |
| // Check if the node ToMerge can be merged, if it is of the same type of |
| // the current node, if it hasn't been merged already and if it is not one |
| // of the successors of the current node. |
| auto ValidForMerge = [&](const Node &ToMerge) { |
| return !Merged.test(ToMerge.getId()) && N.canMergeWith(ToMerge); |
| }; |
| auto ToMergeIt = std::find_if(std::next(It), Graph.end(), ValidForMerge); |
| if (ToMergeIt == Graph.end()) { |
| ++It; |
| // Didn't find anything to merge with |
| continue; |
| } |
| // Perform the merge |
| ++MergedNodes; |
| LLVM_DEBUG(dbgs() << "Merging Node " << N.getId() << " with node " |
| << ToMergeIt->getId() << "\n"); |
| Merged.set(N.getId()); |
| Merged.set(ToMergeIt->getId()); |
| It = mergeNodes(It, ToMergeIt); |
| Changed = true; |
| } |
| (void)MergedNodes; |
| LLVM_DEBUG(dbgs() << "Num of merged nodes: " << MergedNodes << "\n"); |
| |
| return Changed; |
| } |
| |
| class TPUXLUOptimizations : public FunctionPass { |
| public: |
| static char ID; |
| TPUXLUOptimizations() : FunctionPass(ID) {} |
| |
| bool runOnFunction(Function &F) override; |
| void getAnalysisUsage(AnalysisUsage &AU) const override { |
| AU.setPreservesCFG(); |
| } |
| |
| StringRef getPassName() const override { return "TPU XLU optimizations"; } |
| |
| private: |
| bool processBasicBlock(BasicBlock &BB); |
| }; |
| char TPUXLUOptimizations::ID = 0; |
| |
| } // namespace |
| |
| INITIALIZE_PASS(TPUXLUOptimizations, DEBUG_TYPE, "TPU XLU optimizations", false, |
| false) |
| |
| Pass *llvm::createTPUXLUOptimizationsPass() { |
| return new TPUXLUOptimizations(); |
| } |
| |
| bool TPUXLUOptimizations::processBasicBlock(BasicBlock &BB) { |
| return TPUXLUBF16Merger(BB).run(); |
| } |
| |
| bool TPUXLUOptimizations::runOnFunction(Function &F) { |
| bool Changed = false; |
| for (auto &BB : F) { |
| Changed |= processBasicBlock(BB); |
| } |
| // After pass verification of valid IR. |
| assert(!verifyFunction(F, &dbgs())); |
| return Changed; |
| } |