1 //===- LoopVersioning.cpp - Utility to version a loop ---------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines a utility class to perform loop versioning. The versioned 10 // loop speculates that otherwise may-aliasing memory accesses don't overlap and 11 // emits checks to prove this. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/Utils/LoopVersioning.h" 16 #include "llvm/Analysis/LoopAccessAnalysis.h" 17 #include "llvm/Analysis/LoopInfo.h" 18 #include "llvm/Analysis/ScalarEvolutionExpander.h" 19 #include "llvm/IR/Dominators.h" 20 #include "llvm/IR/MDBuilder.h" 21 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 22 #include "llvm/Transforms/Utils/Cloning.h" 23 24 using namespace llvm; 25 26 static cl::opt<bool> 27 AnnotateNoAlias("loop-version-annotate-no-alias", cl::init(true), 28 cl::Hidden, 29 cl::desc("Add no-alias annotation for instructions that " 30 "are disambiguated by memchecks")); 31 32 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI, Loop *L, LoopInfo *LI, 33 DominatorTree *DT, ScalarEvolution *SE, 34 bool UseLAIChecks) 35 : VersionedLoop(L), NonVersionedLoop(nullptr), LAI(LAI), LI(LI), DT(DT), 36 SE(SE) { 37 assert(L->getExitBlock() && "No single exit block"); 38 assert(L->isLoopSimplifyForm() && "Loop is not in loop-simplify form"); 39 if (UseLAIChecks) { 40 setAliasChecks(LAI.getRuntimePointerChecking()->getChecks()); 41 setSCEVChecks(LAI.getPSE().getUnionPredicate()); 42 } 43 } 44 45 void LoopVersioning::setAliasChecks( 46 SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks) { 47 AliasChecks = std::move(Checks); 48 } 49 50 void LoopVersioning::setSCEVChecks(SCEVUnionPredicate Check) { 51 Preds = std::move(Check); 52 } 53 54 void LoopVersioning::versionLoop( 55 const SmallVectorImpl<Instruction *> &DefsUsedOutside) { 56 Instruction *FirstCheckInst; 57 Instruction *MemRuntimeCheck; 58 Value *SCEVRuntimeCheck; 59 Value *RuntimeCheck = nullptr; 60 61 // Add the memcheck in the original preheader (this is empty initially). 62 BasicBlock *RuntimeCheckBB = VersionedLoop->getLoopPreheader(); 63 std::tie(FirstCheckInst, MemRuntimeCheck) = 64 LAI.addRuntimeChecks(RuntimeCheckBB->getTerminator(), AliasChecks); 65 66 const SCEVUnionPredicate &Pred = LAI.getPSE().getUnionPredicate(); 67 SCEVExpander Exp(*SE, RuntimeCheckBB->getModule()->getDataLayout(), 68 "scev.check"); 69 SCEVRuntimeCheck = 70 Exp.expandCodeForPredicate(&Pred, RuntimeCheckBB->getTerminator()); 71 auto *CI = dyn_cast<ConstantInt>(SCEVRuntimeCheck); 72 73 // Discard the SCEV runtime check if it is always true. 74 if (CI && CI->isZero()) 75 SCEVRuntimeCheck = nullptr; 76 77 if (MemRuntimeCheck && SCEVRuntimeCheck) { 78 RuntimeCheck = BinaryOperator::Create(Instruction::Or, MemRuntimeCheck, 79 SCEVRuntimeCheck, "lver.safe"); 80 if (auto *I = dyn_cast<Instruction>(RuntimeCheck)) 81 I->insertBefore(RuntimeCheckBB->getTerminator()); 82 } else 83 RuntimeCheck = MemRuntimeCheck ? MemRuntimeCheck : SCEVRuntimeCheck; 84 85 assert(RuntimeCheck && "called even though we don't need " 86 "any runtime checks"); 87 88 // Rename the block to make the IR more readable. 89 RuntimeCheckBB->setName(VersionedLoop->getHeader()->getName() + 90 ".lver.check"); 91 92 // Create empty preheader for the loop (and after cloning for the 93 // non-versioned loop). 94 BasicBlock *PH = 95 SplitBlock(RuntimeCheckBB, RuntimeCheckBB->getTerminator(), DT, LI); 96 PH->setName(VersionedLoop->getHeader()->getName() + ".ph"); 97 98 // Clone the loop including the preheader. 99 // 100 // FIXME: This does not currently preserve SimplifyLoop because the exit 101 // block is a join between the two loops. 102 SmallVector<BasicBlock *, 8> NonVersionedLoopBlocks; 103 NonVersionedLoop = 104 cloneLoopWithPreheader(PH, RuntimeCheckBB, VersionedLoop, VMap, 105 ".lver.orig", LI, DT, NonVersionedLoopBlocks); 106 remapInstructionsInBlocks(NonVersionedLoopBlocks, VMap); 107 108 // Insert the conditional branch based on the result of the memchecks. 109 Instruction *OrigTerm = RuntimeCheckBB->getTerminator(); 110 BranchInst::Create(NonVersionedLoop->getLoopPreheader(), 111 VersionedLoop->getLoopPreheader(), RuntimeCheck, OrigTerm); 112 OrigTerm->eraseFromParent(); 113 114 // The loops merge in the original exit block. This is now dominated by the 115 // memchecking block. 116 DT->changeImmediateDominator(VersionedLoop->getExitBlock(), RuntimeCheckBB); 117 118 // Adds the necessary PHI nodes for the versioned loops based on the 119 // loop-defined values used outside of the loop. 120 addPHINodes(DefsUsedOutside); 121 } 122 123 void LoopVersioning::addPHINodes( 124 const SmallVectorImpl<Instruction *> &DefsUsedOutside) { 125 BasicBlock *PHIBlock = VersionedLoop->getExitBlock(); 126 assert(PHIBlock && "No single successor to loop exit block"); 127 PHINode *PN; 128 129 // First add a single-operand PHI for each DefsUsedOutside if one does not 130 // exists yet. 131 for (auto *Inst : DefsUsedOutside) { 132 // See if we have a single-operand PHI with the value defined by the 133 // original loop. 134 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) { 135 if (PN->getIncomingValue(0) == Inst) 136 break; 137 } 138 // If not create it. 139 if (!PN) { 140 PN = PHINode::Create(Inst->getType(), 2, Inst->getName() + ".lver", 141 &PHIBlock->front()); 142 SmallVector<User*, 8> UsersToUpdate; 143 for (User *U : Inst->users()) 144 if (!VersionedLoop->contains(cast<Instruction>(U)->getParent())) 145 UsersToUpdate.push_back(U); 146 for (User *U : UsersToUpdate) 147 U->replaceUsesOfWith(Inst, PN); 148 PN->addIncoming(Inst, VersionedLoop->getExitingBlock()); 149 } 150 } 151 152 // Then for each PHI add the operand for the edge from the cloned loop. 153 for (auto I = PHIBlock->begin(); (PN = dyn_cast<PHINode>(I)); ++I) { 154 assert(PN->getNumOperands() == 1 && 155 "Exit block should only have on predecessor"); 156 157 // If the definition was cloned used that otherwise use the same value. 158 Value *ClonedValue = PN->getIncomingValue(0); 159 auto Mapped = VMap.find(ClonedValue); 160 if (Mapped != VMap.end()) 161 ClonedValue = Mapped->second; 162 163 PN->addIncoming(ClonedValue, NonVersionedLoop->getExitingBlock()); 164 } 165 } 166 167 void LoopVersioning::prepareNoAliasMetadata() { 168 // We need to turn the no-alias relation between pointer checking groups into 169 // no-aliasing annotations between instructions. 170 // 171 // We accomplish this by mapping each pointer checking group (a set of 172 // pointers memchecked together) to an alias scope and then also mapping each 173 // group to the list of scopes it can't alias. 174 175 const RuntimePointerChecking *RtPtrChecking = LAI.getRuntimePointerChecking(); 176 LLVMContext &Context = VersionedLoop->getHeader()->getContext(); 177 178 // First allocate an aliasing scope for each pointer checking group. 179 // 180 // While traversing through the checking groups in the loop, also create a 181 // reverse map from pointers to the pointer checking group they were assigned 182 // to. 183 MDBuilder MDB(Context); 184 MDNode *Domain = MDB.createAnonymousAliasScopeDomain("LVerDomain"); 185 186 for (const auto &Group : RtPtrChecking->CheckingGroups) { 187 GroupToScope[&Group] = MDB.createAnonymousAliasScope(Domain); 188 189 for (unsigned PtrIdx : Group.Members) 190 PtrToGroup[RtPtrChecking->getPointerInfo(PtrIdx).PointerValue] = &Group; 191 } 192 193 // Go through the checks and for each pointer group, collect the scopes for 194 // each non-aliasing pointer group. 195 DenseMap<const RuntimePointerChecking::CheckingPtrGroup *, 196 SmallVector<Metadata *, 4>> 197 GroupToNonAliasingScopes; 198 199 for (const auto &Check : AliasChecks) 200 GroupToNonAliasingScopes[Check.first].push_back(GroupToScope[Check.second]); 201 202 // Finally, transform the above to actually map to scope list which is what 203 // the metadata uses. 204 205 for (auto Pair : GroupToNonAliasingScopes) 206 GroupToNonAliasingScopeList[Pair.first] = MDNode::get(Context, Pair.second); 207 } 208 209 void LoopVersioning::annotateLoopWithNoAlias() { 210 if (!AnnotateNoAlias) 211 return; 212 213 // First prepare the maps. 214 prepareNoAliasMetadata(); 215 216 // Add the scope and no-alias metadata to the instructions. 217 for (Instruction *I : LAI.getDepChecker().getMemoryInstructions()) { 218 annotateInstWithNoAlias(I); 219 } 220 } 221 222 void LoopVersioning::annotateInstWithNoAlias(Instruction *VersionedInst, 223 const Instruction *OrigInst) { 224 if (!AnnotateNoAlias) 225 return; 226 227 LLVMContext &Context = VersionedLoop->getHeader()->getContext(); 228 const Value *Ptr = isa<LoadInst>(OrigInst) 229 ? cast<LoadInst>(OrigInst)->getPointerOperand() 230 : cast<StoreInst>(OrigInst)->getPointerOperand(); 231 232 // Find the group for the pointer and then add the scope metadata. 233 auto Group = PtrToGroup.find(Ptr); 234 if (Group != PtrToGroup.end()) { 235 VersionedInst->setMetadata( 236 LLVMContext::MD_alias_scope, 237 MDNode::concatenate( 238 VersionedInst->getMetadata(LLVMContext::MD_alias_scope), 239 MDNode::get(Context, GroupToScope[Group->second]))); 240 241 // Add the no-alias metadata. 242 auto NonAliasingScopeList = GroupToNonAliasingScopeList.find(Group->second); 243 if (NonAliasingScopeList != GroupToNonAliasingScopeList.end()) 244 VersionedInst->setMetadata( 245 LLVMContext::MD_noalias, 246 MDNode::concatenate( 247 VersionedInst->getMetadata(LLVMContext::MD_noalias), 248 NonAliasingScopeList->second)); 249 } 250 } 251 252 namespace { 253 /// Also expose this is a pass. Currently this is only used for 254 /// unit-testing. It adds all memchecks necessary to remove all may-aliasing 255 /// array accesses from the loop. 256 class LoopVersioningPass : public FunctionPass { 257 public: 258 LoopVersioningPass() : FunctionPass(ID) { 259 initializeLoopVersioningPassPass(*PassRegistry::getPassRegistry()); 260 } 261 262 bool runOnFunction(Function &F) override { 263 auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 264 auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); 265 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 266 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 267 268 // Build up a worklist of inner-loops to version. This is necessary as the 269 // act of versioning a loop creates new loops and can invalidate iterators 270 // across the loops. 271 SmallVector<Loop *, 8> Worklist; 272 273 for (Loop *TopLevelLoop : *LI) 274 for (Loop *L : depth_first(TopLevelLoop)) 275 // We only handle inner-most loops. 276 if (L->empty()) 277 Worklist.push_back(L); 278 279 // Now walk the identified inner loops. 280 bool Changed = false; 281 for (Loop *L : Worklist) { 282 const LoopAccessInfo &LAI = LAA->getInfo(L); 283 if (L->isLoopSimplifyForm() && !LAI.hasConvergentOp() && 284 (LAI.getNumRuntimePointerChecks() || 285 !LAI.getPSE().getUnionPredicate().isAlwaysTrue())) { 286 LoopVersioning LVer(LAI, L, LI, DT, SE); 287 LVer.versionLoop(); 288 LVer.annotateLoopWithNoAlias(); 289 Changed = true; 290 } 291 } 292 293 return Changed; 294 } 295 296 void getAnalysisUsage(AnalysisUsage &AU) const override { 297 AU.addRequired<LoopInfoWrapperPass>(); 298 AU.addPreserved<LoopInfoWrapperPass>(); 299 AU.addRequired<LoopAccessLegacyAnalysis>(); 300 AU.addRequired<DominatorTreeWrapperPass>(); 301 AU.addPreserved<DominatorTreeWrapperPass>(); 302 AU.addRequired<ScalarEvolutionWrapperPass>(); 303 } 304 305 static char ID; 306 }; 307 } 308 309 #define LVER_OPTION "loop-versioning" 310 #define DEBUG_TYPE LVER_OPTION 311 312 char LoopVersioningPass::ID; 313 static const char LVer_name[] = "Loop Versioning"; 314 315 INITIALIZE_PASS_BEGIN(LoopVersioningPass, LVER_OPTION, LVer_name, false, false) 316 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 317 INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) 318 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 319 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 320 INITIALIZE_PASS_END(LoopVersioningPass, LVER_OPTION, LVer_name, false, false) 321 322 namespace llvm { 323 FunctionPass *createLoopVersioningPass() { 324 return new LoopVersioningPass(); 325 } 326 } 327