1 //===- Region.cpp ---------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/SandboxIR/Region.h"
10 #include "llvm/SandboxIR/Function.h"
11
12 namespace llvm::sandboxir {
13
getCost(Instruction * I) const14 InstructionCost ScoreBoard::getCost(Instruction *I) const {
15 auto *LLVMI = cast<llvm::Instruction>(I->Val);
16 SmallVector<const llvm::Value *> Operands(LLVMI->operands());
17 return TTI.getInstructionCost(LLVMI, Operands, CostKind);
18 }
19
remove(Instruction * I)20 void ScoreBoard::remove(Instruction *I) {
21 auto Cost = getCost(I);
22 if (Rgn.contains(I))
23 // If `I` is one the newly added ones, then we should adjust `AfterCost`
24 AfterCost -= Cost;
25 else
26 // If `I` is one of the original instructions (outside the region) then it
27 // is part of the original code, so adjust `BeforeCost`.
28 BeforeCost += Cost;
29 }
30
31 #ifndef NDEBUG
dump() const32 void ScoreBoard::dump() const { dump(dbgs()); }
33 #endif
34
Region(Context & Ctx,TargetTransformInfo & TTI)35 Region::Region(Context &Ctx, TargetTransformInfo &TTI)
36 : Ctx(Ctx), Scoreboard(*this, TTI) {
37 LLVMContext &LLVMCtx = Ctx.LLVMCtx;
38 auto *RegionStrMD = MDString::get(LLVMCtx, RegionStr);
39 RegionMDN = MDNode::getDistinct(LLVMCtx, {RegionStrMD});
40
41 CreateInstCB = Ctx.registerCreateInstrCallback(
42 [this](Instruction *NewInst) { add(NewInst); });
43 EraseInstCB = Ctx.registerEraseInstrCallback([this](Instruction *ErasedInst) {
44 remove(ErasedInst);
45 removeFromAux(ErasedInst);
46 });
47 }
48
~Region()49 Region::~Region() {
50 Ctx.unregisterCreateInstrCallback(CreateInstCB);
51 Ctx.unregisterEraseInstrCallback(EraseInstCB);
52 }
53
addImpl(Instruction * I,bool IgnoreCost)54 void Region::addImpl(Instruction *I, bool IgnoreCost) {
55 Insts.insert(I);
56 // TODO: Consider tagging instructions lazily.
57 cast<llvm::Instruction>(I->Val)->setMetadata(MDKind, RegionMDN);
58 if (!IgnoreCost)
59 // Keep track of the instruction cost.
60 Scoreboard.add(I);
61 }
62
setAux(ArrayRef<Instruction * > Aux)63 void Region::setAux(ArrayRef<Instruction *> Aux) {
64 this->Aux = SmallVector<Instruction *>(Aux);
65 auto &LLVMCtx = Ctx.LLVMCtx;
66 for (auto [Idx, I] : enumerate(Aux)) {
67 llvm::ConstantInt *IdxC =
68 llvm::ConstantInt::get(llvm::Type::getInt32Ty(LLVMCtx), Idx, false);
69 assert(cast<llvm::Instruction>(I->Val)->getMetadata(AuxMDKind) == nullptr &&
70 "Instruction already in Aux!");
71 cast<llvm::Instruction>(I->Val)->setMetadata(
72 AuxMDKind, MDNode::get(LLVMCtx, ConstantAsMetadata::get(IdxC)));
73 // Aux instrs should always be in a region.
74 addImpl(I, /*DontTrackCost=*/true);
75 }
76 }
77
setAux(unsigned Idx,Instruction * I)78 void Region::setAux(unsigned Idx, Instruction *I) {
79 assert((Idx >= Aux.size() || Aux[Idx] == nullptr) &&
80 "There is already an Instruction at Idx in Aux!");
81 unsigned ExpectedSz = Idx + 1;
82 if (Aux.size() < ExpectedSz) {
83 auto SzBefore = Aux.size();
84 Aux.resize(ExpectedSz);
85 // Initialize the gap with nullptr.
86 for (unsigned Idx = SzBefore; Idx + 1 < ExpectedSz; ++Idx)
87 Aux[Idx] = nullptr;
88 }
89 Aux[Idx] = I;
90 // Aux instrs should always be in a region.
91 addImpl(I, /*DontTrackCost=*/true);
92 }
93
dropAuxMetadata(Instruction * I)94 void Region::dropAuxMetadata(Instruction *I) {
95 auto *LLVMI = cast<llvm::Instruction>(I->Val);
96 LLVMI->setMetadata(AuxMDKind, nullptr);
97 }
98
removeFromAux(Instruction * I)99 void Region::removeFromAux(Instruction *I) {
100 auto It = find(Aux, I);
101 if (It == Aux.end())
102 return;
103 dropAuxMetadata(I);
104 Aux.erase(It);
105 }
106
clearAux()107 void Region::clearAux() {
108 for (unsigned Idx : seq<unsigned>(0, Aux.size()))
109 dropAuxMetadata(Aux[Idx]);
110 Aux.clear();
111 }
112
remove(Instruction * I)113 void Region::remove(Instruction *I) {
114 // Keep track of the instruction cost. This need to be done *before* we remove
115 // `I` from the region.
116 Scoreboard.remove(I);
117
118 Insts.remove(I);
119 cast<llvm::Instruction>(I->Val)->setMetadata(MDKind, nullptr);
120 }
121
122 #ifndef NDEBUG
operator ==(const Region & Other) const123 bool Region::operator==(const Region &Other) const {
124 if (Insts.size() != Other.Insts.size())
125 return false;
126 if (!std::is_permutation(Insts.begin(), Insts.end(), Other.Insts.begin()))
127 return false;
128 return true;
129 }
130
dump(raw_ostream & OS) const131 void Region::dump(raw_ostream &OS) const {
132 for (auto *I : Insts)
133 OS << *I << "\n";
134 if (!Aux.empty()) {
135 OS << "\nAux:\n";
136 for (auto *I : Aux) {
137 if (I == nullptr)
138 OS << "NULL\n";
139 else
140 OS << *I << "\n";
141 }
142 }
143 }
144
dump() const145 void Region::dump() const {
146 dump(dbgs());
147 dbgs() << "\n";
148 }
149 #endif // NDEBUG
150
151 SmallVector<std::unique_ptr<Region>>
createRegionsFromMD(Function & F,TargetTransformInfo & TTI)152 Region::createRegionsFromMD(Function &F, TargetTransformInfo &TTI) {
153 SmallVector<std::unique_ptr<Region>> Regions;
154 DenseMap<MDNode *, Region *> MDNToRegion;
155 auto &Ctx = F.getContext();
156 for (BasicBlock &BB : F) {
157 for (Instruction &Inst : BB) {
158 auto *LLVMI = cast<llvm::Instruction>(Inst.Val);
159 Region *R = nullptr;
160 if (auto *MDN = LLVMI->getMetadata(MDKind)) {
161 auto [It, Inserted] = MDNToRegion.try_emplace(MDN);
162 if (Inserted) {
163 Regions.push_back(std::make_unique<Region>(Ctx, TTI));
164 R = Regions.back().get();
165 It->second = R;
166 } else {
167 R = It->second;
168 }
169 R->addImpl(&Inst, /*IgnoreCost=*/true);
170 }
171 if (auto *AuxMDN = LLVMI->getMetadata(AuxMDKind)) {
172 llvm::Constant *IdxC =
173 dyn_cast<ConstantAsMetadata>(AuxMDN->getOperand(0))->getValue();
174 auto Idx = cast<llvm::ConstantInt>(IdxC)->getSExtValue();
175 if (R == nullptr) {
176 errs() << "No region specified for Aux: '" << *LLVMI << "'\n";
177 exit(1);
178 }
179 R->setAux(Idx, &Inst);
180 }
181 }
182 }
183 #ifndef NDEBUG
184 // Check that there are no gaps in the Aux vector.
185 for (auto &RPtr : Regions)
186 for (auto *I : RPtr->getAux())
187 assert(I != nullptr && "Gap in Aux!");
188 #endif
189 return Regions;
190 }
191
192 } // namespace llvm::sandboxir
193