1 //===- HexagonLoopIdiomRecognition.cpp ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "HexagonLoopIdiomRecognition.h" 10 #include "llvm/ADT/APInt.h" 11 #include "llvm/ADT/DenseMap.h" 12 #include "llvm/ADT/SetVector.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 #include "llvm/ADT/SmallSet.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/ADT/StringRef.h" 17 #include "llvm/Analysis/AliasAnalysis.h" 18 #include "llvm/Analysis/InstructionSimplify.h" 19 #include "llvm/Analysis/LoopAnalysisManager.h" 20 #include "llvm/Analysis/LoopInfo.h" 21 #include "llvm/Analysis/LoopPass.h" 22 #include "llvm/Analysis/MemoryLocation.h" 23 #include "llvm/Analysis/ScalarEvolution.h" 24 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 25 #include "llvm/Analysis/TargetLibraryInfo.h" 26 #include "llvm/Analysis/ValueTracking.h" 27 #include "llvm/IR/Attributes.h" 28 #include "llvm/IR/BasicBlock.h" 29 #include "llvm/IR/Constant.h" 30 #include "llvm/IR/Constants.h" 31 #include "llvm/IR/DataLayout.h" 32 #include "llvm/IR/DebugLoc.h" 33 #include "llvm/IR/DerivedTypes.h" 34 #include "llvm/IR/Dominators.h" 35 #include "llvm/IR/Function.h" 36 #include "llvm/IR/IRBuilder.h" 37 #include "llvm/IR/InstrTypes.h" 38 #include "llvm/IR/Instruction.h" 39 #include "llvm/IR/Instructions.h" 40 #include "llvm/IR/IntrinsicInst.h" 41 #include "llvm/IR/Intrinsics.h" 42 #include "llvm/IR/IntrinsicsHexagon.h" 43 #include "llvm/IR/Module.h" 44 #include "llvm/IR/PassManager.h" 45 #include "llvm/IR/PatternMatch.h" 46 #include "llvm/IR/Type.h" 47 #include "llvm/IR/User.h" 48 #include "llvm/IR/Value.h" 49 #include "llvm/InitializePasses.h" 50 #include "llvm/Pass.h" 51 #include "llvm/Support/Casting.h" 52 
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>

#define DEBUG_TYPE "hexagon-lir"

using namespace llvm;

// Command-line knobs controlling which memory-transfer idioms this pass is
// allowed to emit, and the size thresholds that gate the transformation.

static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memcpy in loop idiom recognition"));

static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memmove in loop idiom recognition"));

static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
  "check guarding the memmove."));

static cl::opt<unsigned> CompileTimeMemSizeThreshold(
  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
  cl::desc("Threshold (in bytes) to perform the transformation, if the "
  "runtime loop count (mem transfer size) is known at compile-time."));

static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
  cl::Hidden, cl::init(true),
  cl::desc("Only enable generating memmove in non-nested loops"));

// NOTE(review): the option name says "disable-..." while the description says
// "Enable ..." — the flag name and description look inconsistent; confirm the
// intended polarity before relying on either. The strings are left untouched
// here since changing them would change user-visible behavior.
static cl::opt<bool> HexagonVolatileMemcpy(
  "disable-hexagon-volatile-memcpy", cl::Hidden, cl::init(false),
  cl::desc("Enable Hexagon-specific memcpy for volatile destination."));

// Upper bound on rewrite steps taken by Simplifier::simplify (see below);
// exceeding it makes simplify() give up and return nullptr.
static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
  cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR"));

// Name of the runtime routine emitted for the volatile-memcpy idiom.
static const char *HexagonVolatileMemcpyName
  = "hexagon_memcpy_forward_vp4cp4n2";


namespace llvm {

void initializeHexagonLoopIdiomRecognizeLegacyPassPass(PassRegistry &);
Pass *createHexagonLoopIdiomPass();

} // end namespace llvm

namespace {

// Driver for the Hexagon loop-idiom recognition: scans countable loops for
// copying stores and rewrites them as memcpy/memmove (or the Hexagon-specific
// volatile-memcpy runtime call) when legal.
class HexagonLoopIdiomRecognize {
public:
  explicit HexagonLoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
      LoopInfo *LF, const TargetLibraryInfo *TLI,
      ScalarEvolution *SE)
    : AA(AA), DT(DT), LF(LF), TLI(TLI), SE(SE) {}

  bool run(Loop *L);

private:
  int getSCEVStride(const SCEVAddRecExpr *StoreEv);
  bool isLegalStore(Loop *CurLoop, StoreInst *SI);
  void collectStores(Loop *CurLoop, BasicBlock *BB,
      SmallVectorImpl<StoreInst *> &Stores);
  bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount);
  bool coverLoop(Loop *L, SmallVectorImpl<Instruction *> &Insts) const;
  bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount,
      SmallVectorImpl<BasicBlock *> &ExitBlocks);
  bool runOnCountableLoop(Loop *L);

  AliasAnalysis *AA;
  const DataLayout *DL;
  DominatorTree *DT;
  LoopInfo *LF;
  const TargetLibraryInfo *TLI;
  ScalarEvolution *SE;
  // Whether the target library provides memcpy/memmove (queried elsewhere).
  bool HasMemcpy, HasMemmove;
};

// Legacy pass-manager wrapper around HexagonLoopIdiomRecognize.
class HexagonLoopIdiomRecognizeLegacyPass : public LoopPass {
public:
  static char ID;

  explicit HexagonLoopIdiomRecognizeLegacyPass() : LoopPass(ID) {
    initializeHexagonLoopIdiomRecognizeLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  StringRef getPassName() const override {
    return "Recognize Hexagon-specific loop idioms";
  }

  // Declare the analyses this pass needs; loops must be in simplified and
  // LCSSA form before the idiom recognition runs.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addRequiredID(LoopSimplifyID);
    AU.addRequiredID(LCSSAID);
    AU.addRequired<AAResultsWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addPreserved<TargetLibraryInfoWrapperPass>();
  }

  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
};

// A small rule-based expression rewriter. It works on a detached clone of an
// instruction tree (see Context below), repeatedly applying user-registered
// rules until no rule fires or SimplifyLimit is reached.
struct Simplifier {
  // A single named rewrite rule: returns a replacement Value for the given
  // instruction, or nullptr if the rule does not apply.
  struct Rule {
    using FuncType = std::function<Value *(Instruction *, LLVMContext &)>;
    Rule(StringRef N, FuncType F) : Name(N), Fn(F) {}
    StringRef Name; // For debugging.
    FuncType Fn;
  };

  void addRule(StringRef N, const Rule::FuncType &F) {
    Rules.push_back(Rule(N, F));
  }

private:
  // FIFO worklist that silently drops values already queued.
  struct WorkListType {
    WorkListType() = default;

    void push_back(Value *V) {
      // Do not push back duplicates.
      if (S.insert(V).second)
        Q.push_back(V);
    }

    Value *pop_front_val() {
      Value *V = Q.front();
      Q.pop_front();
      S.erase(V);
      return V;
    }

    bool empty() const { return Q.empty(); }

  private:
    std::deque<Value *> Q;  // queue preserving FIFO order
    std::set<Value *> S;    // membership set used for de-duplication
  };

  using ValueSetType = std::set<Value *>;

  std::vector<Rule> Rules;

public:
  // Scratchpad holding a deep clone of the expression being simplified.
  // Cloned instructions are kept *unlinked* (getParent() == nullptr); that
  // property is what the traversals below use to tell clones apart from
  // instructions that already live in a basic block.
  struct Context {
    using ValueMapType = DenseMap<Value *, Value *>;

    Value *Root;
    ValueSetType Used;    // The set of all cloned values used by Root.
    ValueSetType Clones;  // The set of all cloned values.
    LLVMContext &Ctx;

    Context(Instruction *Exp)
      : Ctx(Exp->getParent()->getParent()->getContext()) {
      initialize(Exp);
    }

    // Destructor deletes any clones that were never inserted into a block.
    ~Context() { cleanup(); }

    void print(raw_ostream &OS, const Value *V) const;
    Value *materialize(BasicBlock *B, BasicBlock::iterator At);

  private:
    friend struct Simplifier;

    void initialize(Instruction *Exp);
    void cleanup();

    template <typename FuncT> void traverse(Value *V, FuncT F);
    void record(Value *V);
    void use(Value *V);
    void unuse(Value *V);

    bool equal(const Instruction *I, const Instruction *J) const;
    Value *find(Value *Tree, Value *Sub) const;
    Value *subst(Value *Tree, Value *OldV, Value *NewV);
    void replace(Value *OldV, Value *NewV);
    void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At);
  };

  Value *simplify(Context &C);
};

// Helper for printing a Context expression (or a sub-expression V of it)
// through operator<<, used in debug output.
struct PE {
  PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {}

  const Simplifier::Context &C;
  const Value *V;
};

LLVM_ATTRIBUTE_USED
raw_ostream &operator<<(raw_ostream &OS, const PE &P) {
  // With no explicit value, print the whole tree rooted at C.Root.
  P.C.print(OS, P.V ? P.V : P.C.Root);
  return OS;
}

} // end anonymous namespace

char HexagonLoopIdiomRecognizeLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom",
                      "Recognize Hexagon-specific loop idioms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom",
                    "Recognize Hexagon-specific loop idioms", false, false)

// Apply F to every *unlinked* (cloned) instruction reachable from V through
// operands, in BFS order. F returning false stops descent into that node's
// operands. Instructions with a parent block are not clones and are skipped.
template <typename FuncT>
void Simplifier::Context::traverse(Value *V, FuncT F) {
  WorkListType Q;
  Q.push_back(V);

  while (!Q.empty()) {
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
    if (!U || U->getParent())
      continue;
    if (!F(U))
      continue;
    for (Value *Op : U->operands())
      Q.push_back(Op);
  }
}

// Debug printer: linked instructions print as operands, unlinked clones are
// printed recursively as "(opcode op1 op2 ...)" trees.
void Simplifier::Context::print(raw_ostream &OS, const Value *V) const {
  const auto *U = dyn_cast<const Instruction>(V);
  if (!U) {
    OS << V << '(' << *V << ')';
    return;
  }

  if (U->getParent()) {
    OS << U << '(';
    U->printAsOperand(OS, true);
    OS << ')';
    return;
  }

  unsigned N = U->getNumOperands();
  if (N != 0)
    OS << U << '(';
  OS << U->getOpcodeName();
  for (const Value *Op : U->operands()) {
    OS << ' ';
    print(OS, Op);
  }
  if (N != 0)
    OS << ')';
}

void Simplifier::Context::initialize(Instruction *Exp) {
  // Perform a deep clone of the expression, set Root to the root
  // of the clone, and build a map from the cloned values to the
  // original ones.
  ValueMapType M;
  BasicBlock *Block = Exp->getParent();
  WorkListType Q;
  Q.push_back(Exp);

  // First pass: clone every non-PHI instruction from Exp's block that is
  // reachable through operands. Clones are created detached (no parent).
  while (!Q.empty()) {
    Value *V = Q.pop_front_val();
    if (M.contains(V))
      continue;
    if (Instruction *U = dyn_cast<Instruction>(V)) {
      // PHIs and instructions from other blocks act as leaves of the tree.
      if (isa<PHINode>(U) || U->getParent() != Block)
        continue;
      for (Value *Op : U->operands())
        Q.push_back(Op);
      M.insert({U, U->clone()});
    }
  }

  // Second pass: rewrite operands of the clones so they refer to other
  // clones (instead of the original instructions) wherever a clone exists.
  for (std::pair<Value*,Value*> P : M) {
    Instruction *U = cast<Instruction>(P.second);
    for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) {
      auto F = M.find(U->getOperand(i));
      if (F != M.end())
        U->setOperand(i, F->second);
    }
  }

  auto R = M.find(Exp);
  assert(R != M.end());
  Root = R->second;

  record(Root);
  use(Root);
}

// Add every clone reachable from V to the Clones set.
void Simplifier::Context::record(Value *V) {
  auto Record = [this](Instruction *U) -> bool {
    Clones.insert(U);
    return true;
  };
  traverse(V, Record);
}

// Mark every clone reachable from V as used by the current Root tree.
void Simplifier::Context::use(Value *V) {
  auto Use = [this](Instruction *U) -> bool {
    Used.insert(U);
    return true;
  };
  traverse(V, Use);
}

// Drop V (and any operand clones that become dead with it) from the Used
// set. Stops descending at clones that still have remaining uses.
void Simplifier::Context::unuse(Value *V) {
  // Only detached clones participate in the Used bookkeeping.
  if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr)
    return;

  auto Unuse = [this](Instruction *U) -> bool {
    if (!U->use_empty())
      return false;
    Used.erase(U);
    return true;
  };
  traverse(V, Unuse);
}

// Replace every use of OldV inside the clone tree Tree with NewV, updating
// the Used bookkeeping for values that lose their last use. Returns the
// (possibly new) root of the tree.
Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) {
  if (Tree == OldV)
    return NewV;
  if (OldV == NewV)
    return Tree;

  WorkListType Q;
  Q.push_back(Tree);
  while (!Q.empty()) {
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
    // If U is not an instruction, or it's not a clone, skip it.
    if (!U || U->getParent())
      continue;
    for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) {
      Value *Op = U->getOperand(i);
      if (Op == OldV) {
        U->setOperand(i, NewV);
        unuse(OldV);
      } else {
        Q.push_back(Op);
      }
    }
  }
  return Tree;
}

// Replace OldV with NewV in the Root tree, first commoning NewV's subtrees
// with equivalent subtrees already present in Root.
void Simplifier::Context::replace(Value *OldV, Value *NewV) {
  if (Root == OldV) {
    Root = NewV;
    use(Root);
    return;
  }

  // NewV may be a complex tree that has just been created by one of the
  // transformation rules. We need to make sure that it is commoned with
  // the existing Root to the maximum extent possible.
  // Identify all subtrees of NewV (including NewV itself) that have
  // equivalent counterparts in Root, and replace those subtrees with
  // these counterparts.
  WorkListType Q;
  Q.push_back(NewV);
  while (!Q.empty()) {
    Value *V = Q.pop_front_val();
    Instruction *U = dyn_cast<Instruction>(V);
    if (!U || U->getParent())
      continue;
    if (Value *DupV = find(Root, V)) {
      if (DupV != V)
        NewV = subst(NewV, V, DupV);
    } else {
      // No duplicate of V in Root; its operands may still have duplicates.
      for (Value *Op : U->operands())
        Q.push_back(Op);
    }
  }

  // Now, simply replace OldV with NewV in Root.
  Root = subst(Root, OldV, NewV);
  use(Root);
}

// Destroy all clones that were never materialized into a basic block.
// References are dropped first so the clones can be deleted in any order.
void Simplifier::Context::cleanup() {
  for (Value *V : Clones) {
    Instruction *U = cast<Instruction>(V);
    if (!U->getParent())
      U->dropAllReferences();
  }

  for (Value *V : Clones) {
    Instruction *U = cast<Instruction>(V);
    if (!U->getParent())
      U->deleteValue();
  }
}

// Structural equality of two instruction trees: same operation and
// recursively equal operands. PHIs are compared with isIdenticalTo.
bool Simplifier::Context::equal(const Instruction *I,
                                const Instruction *J) const {
  if (I == J)
    return true;
  if (!I->isSameOperationAs(J))
    return false;
  if (isa<PHINode>(I))
    return I->isIdenticalTo(J);

  for (unsigned i = 0, n = I->getNumOperands(); i != n; ++i) {
    Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i);
    if (OpI == OpJ)
      continue;
    auto *InI = dyn_cast<const Instruction>(OpI);
    auto *InJ = dyn_cast<const Instruction>(OpJ);
    if (InI && InJ) {
      if (!equal(InI, InJ))
        return false;
    } else if (InI != InJ || !InI)
      // Here at least one operand is not an instruction; distinct
      // non-instruction operands (or a mixed pair) are never equal.
      return false;
  }
  return true;
}

// Search the clone tree Tree for a node equal to Sub (pointer-identical, or
// structurally equal per equal() when Sub is an instruction).
Value *Simplifier::Context::find(Value *Tree, Value *Sub) const {
  Instruction *SubI = dyn_cast<Instruction>(Sub);
  WorkListType Q;
  Q.push_back(Tree);

  while (!Q.empty()) {
    Value *V = Q.pop_front_val();
    if (V == Sub)
      return V;
    Instruction *U = dyn_cast<Instruction>(V);
    if (!U || U->getParent())
      continue;
    if (SubI && equal(SubI, U))
      return U;
    assert(!isa<PHINode>(U));
    for (Value *Op : U->operands())
      Q.push_back(Op);
  }
  return nullptr;
}

// Insert the clone I (and, first, any unlinked operand clones) into block B
// before At. Operands are linked before their users so uses stay dominated.
void Simplifier::Context::link(Instruction *I, BasicBlock *B,
                               BasicBlock::iterator At) {
  if (I->getParent())
    return;

  for (Value *Op : I->operands()) {
    if (Instruction *OpI = dyn_cast<Instruction>(Op))
      link(OpI, B, At);
  }

  I->insertInto(B, At);
}

// Insert the simplified tree into block B at position At and return Root.
Value *Simplifier::Context::materialize(BasicBlock *B,
                                        BasicBlock::iterator At) {
  if (Instruction *RootI = dyn_cast<Instruction>(Root))
    link(RootI, B, At);
  return Root;
}

// Repeatedly apply the registered rules to the clone tree until no rule
// matches or SimplifyLimit iterations have been spent. Returns the
// simplified root, or nullptr if the step budget was exhausted.
Value *Simplifier::simplify(Context &C) {
  WorkListType Q;
  Q.push_back(C.Root);
  unsigned Count = 0;
  const unsigned Limit = SimplifyLimit;

  while (!Q.empty()) {
    if (Count++ >= Limit)
      break;
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
    // Only visit detached clones that are still part of the Root tree.
    if (!U || U->getParent() || !C.Used.count(U))
      continue;
    bool Changed = false;
    for (Rule &R : Rules) {
      Value *W = R.Fn(U, C.Ctx);
      if (!W)
        continue;
      Changed = true;
      C.record(W);
      C.replace(U, W);
      // A rewrite may restructure the whole tree; rescan from the root.
      Q.push_back(C.Root);
      break;
    }
    if (!Changed) {
      for (Value *Op : U->operands())
        Q.push_back(Op);
    }
  }
  return Count < Limit ? C.Root : nullptr;
}

//===----------------------------------------------------------------------===//
//
// Implementation of PolynomialMultiplyRecognize
//
//===----------------------------------------------------------------------===//

namespace {

// Recognizes loops that implement GF(2) (carry-less) polynomial
// multiplication / division remainder idioms so they can be replaced with
// the Hexagon pmpy instruction.
class PolynomialMultiplyRecognize {
public:
  explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl,
      const DominatorTree &dt, const TargetLibraryInfo &tli,
      ScalarEvolution &se)
    : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {}

  bool recognize();

private:
  using ValueSeq = SetVector<Value *>;

  // The integer type the idiom is normalized to: i32.
  IntegerType *getPmpyType() const {
    LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext();
    return IntegerType::get(Ctx, 32);
  }

  bool isPromotableTo(Value *V, IntegerType *Ty);
  void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB);
  bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB);

  Value *getCountIV(BasicBlock *BB);
  bool findCycle(Value *Out, Value *In, ValueSeq &Cycle);
  void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early,
                     ValueSeq &Late);
  bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late);
  bool commutesWithShift(Instruction *I);
  bool highBitsAreZero(Value *V, unsigned IterCount);
  bool keepsHighBitsZero(Value *V, unsigned IterCount);
  bool isOperandShifted(Instruction *I, Value *Op);
  bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB,
                           unsigned IterCount);
  void cleanupLoopBody(BasicBlock *LoopB);

  // The values extracted from the recognized select pattern; see
  // scanSelect/matchLeftShift/matchRightShift for how each field is filled.
  struct ParsedValues {
    ParsedValues() = default;

    Value *M = nullptr;          // invariant xor mask (inverse pmpy variant)
    Value *P = nullptr;          // input polynomial
    Value *Q = nullptr;          // multiplier polynomial
    Value *R = nullptr;          // accumulated result (the recurrence)
    Value *X = nullptr;          // value whose bit is tested each iteration
    Instruction *Res = nullptr;  // the select producing the next R
    unsigned IterCount = 0;
    bool Left = false;           // left-shifting vs right-shifting idiom
    bool Inv = false;            // inverse pmpy variant
  };

  bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV);
  bool matchRightShift(SelectInst *SelI, ParsedValues &PV);
  bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB,
                  Value *CIV, ParsedValues &PV, bool PreScan);
  unsigned getInverseMxN(unsigned QP);
  Value *generate(BasicBlock::iterator At, ParsedValues &PV);

  void setupPreSimplifier(Simplifier &S);
  void setupPostSimplifier(Simplifier &S);

  Loop *CurLoop;
  const DataLayout &DL;
  const DominatorTree &DT;
  const TargetLibraryInfo &TLI;
  ScalarEvolution &SE;
};

} // end anonymous namespace

// Find a PHI in BB that acts as a counting induction variable: it starts at
// 0 (from the non-loop predecessor) and is incremented by 1 each iteration.
// Returns the PHI, or nullptr if none matches. BB must have exactly two
// predecessors, one of which is BB itself (a self-looping block).
Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) {
  pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
  if (std::distance(PI, PE) != 2)
    return nullptr;
  // PB is the predecessor that is not the loop back-edge (BB itself).
  BasicBlock *PB = (*PI == BB) ? *std::next(PI) : *PI;

  for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) {
    auto *PN = cast<PHINode>(I);
    // The initial value must be the constant 0.
    Value *InitV = PN->getIncomingValueForBlock(PB);
    if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero())
      continue;
    // The back-edge value must be PN + 1 (in either operand order).
    Value *IterV = PN->getIncomingValueForBlock(BB);
    auto *BO = dyn_cast<BinaryOperator>(IterV);
    if (!BO)
      continue;
    if (BO->getOpcode() != Instruction::Add)
      continue;
    Value *IncV = nullptr;
    if (BO->getOperand(0) == PN)
      IncV = BO->getOperand(1);
    else if (BO->getOperand(1) == PN)
      IncV = BO->getOperand(0);
    if (IncV == nullptr)
      continue;

    if (auto *T = dyn_cast<ConstantInt>(IncV))
      if (T->isOne())
        return PN;
  }
  return nullptr;
}

// Replace all uses of I with J, but only in users that reside in block BB.
// The iterator is advanced before the use is rewritten, since rewriting
// invalidates the current use.
static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) {
  for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) {
    Use &TheUse = UI.getUse();
    ++UI;
    if (auto *II = dyn_cast<Instruction>(TheUse.getUser()))
      if (BB == II->getParent())
        II->replaceUsesOfWith(I, J);
  }
}

bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
      Value *CIV, ParsedValues &PV) {
  // Match the following:
  //   select (X & (1 << i)) != 0 ? R ^ (Q << i) : R
  //   select (X & (1 << i)) == 0 ? R : R ^ (Q << i)
  // The condition may also check for equality with the masked value, i.e
  //   select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R
  //   select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i);
  // On success fills PV.X, PV.Q, PV.R and sets PV.Left.

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  CmpInst::Predicate P;
  Value *A = nullptr, *B = nullptr, *C = nullptr;

  // The and-with-mask may appear on either side of the comparison.
  if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
      !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B)))))
    return false;
  if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
    return false;
  // Matched: select (A & B) == C ? ... : ...
  //          select (A & B) != C ? ... : ...

  Value *X = nullptr, *Sh1 = nullptr;
  // Check (A & B) for (X & (1 << i)):
  if (match(A, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = A;
    X = B;
  } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = B;
    X = A;
  } else {
    // TODO: Could also check for an induction variable containing single
    // bit shifted left by 1 in each iteration.
    return false;
  }

  bool TrueIfZero;

  // Check C against the possible values for comparison: 0 and (1 << i):
  if (match(C, m_Zero()))
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  else if (C == Sh1)
    TrueIfZero = (P == CmpInst::ICMP_NE);
  else
    return false;

  // So far, matched:
  //   select (X & (1 << i)) ? ... : ...
  // including variations of the check against zero/non-zero value.

  Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr;
  if (TrueIfZero) {
    ShouldSameV = TrueV;
    ShouldXoredV = FalseV;
  } else {
    ShouldSameV = FalseV;
    ShouldXoredV = TrueV;
  }

  Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr;
  Value *T = nullptr;
  if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) {
    // Matched: select +++ ? ... : Y ^ Z
    //          select +++ ? Y ^ Z : ...
    // where +++ denotes previously checked matches.
    // One xor operand must be the unchanged value R; the other is the
    // term T that gets xor-ed in when the tested bit is set.
    if (ShouldSameV == Y)
      T = Z;
    else if (ShouldSameV == Z)
      T = Y;
    else
      return false;
    R = ShouldSameV;
    // Matched: select +++ ? R : R ^ T
    //          select +++ ? R ^ T : R
    // depending on TrueIfZero.

  } else if (match(ShouldSameV, m_Zero())) {
    // Matched: select +++ ? 0 : ...
    //          select +++ ? ... : 0
    if (!SelI->hasOneUse())
      return false;
    T = ShouldXoredV;
    // Matched: select +++ ? 0 : T
    //          select +++ ? T : 0

    // The xor with R then happens outside the select, in its sole user.
    Value *U = *SelI->user_begin();
    if (!match(U, m_c_Xor(m_Specific(SelI), m_Value(R))))
      return false;
    // Matched: xor (select +++ ? 0 : T), R
    //          xor (select +++ ? T : 0), R
  } else
    return false;

  // The xor input value T is isolated into its own match so that it could
  // be checked against an induction variable containing a shifted bit
  // (todo).
  // For now, check against (Q << i).
  if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) &&
      !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV)))))
    return false;
  // Matched: select +++ ? R : R ^ (Q << i)
  //          select +++ ? R ^ (Q << i) : R

  PV.X = X;
  PV.Q = Q;
  PV.R = R;
  PV.Left = true;
  return true;
}

bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
      ParsedValues &PV) {
  // Match the following:
  //   select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q
  // The condition may also check for equality with the masked value, i.e
  //   select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q
  // On success fills PV.X, PV.Q, PV.R and clears PV.Left.

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  Value *C = nullptr;
  CmpInst::Predicate P;
  bool TrueIfZero;

  if (match(CondV, m_c_ICmp(P, m_Value(C), m_Zero()))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 0 ? ... : ...
    //          select C != 0 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  } else if (match(CondV, m_c_ICmp(P, m_Value(C), m_One()))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 1 ? ... : ...
    //          select C != 1 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_NE);
  } else
    return false;

  Value *X = nullptr;
  if (!match(C, m_And(m_Value(X), m_One())))
    return false;
  // Matched: select (X & 1) == +++ ? ... : ...
  //          select (X & 1) != +++ ? ... : ...

  Value *R = nullptr, *Q = nullptr;
  if (TrueIfZero) {
    // The select's condition is true if the tested bit is 0.
    // TrueV must be the shift, FalseV must be the xor.
    if (!match(TrueV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? (R >> 1) : ...
    if (!match(FalseV, m_c_Xor(m_Specific(TrueV), m_Value(Q))))
      return false;
    // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q
    // with commuting ^.
  } else {
    // The select's condition is true if the tested bit is 1.
    // TrueV must be the xor, FalseV must be the shift.
    if (!match(FalseV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? ... : (R >> 1)
    if (!match(TrueV, m_c_Xor(m_Specific(FalseV), m_Value(Q))))
      return false;
    // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1)
    // with commuting ^.
858 } 859 860 PV.X = X; 861 PV.Q = Q; 862 PV.R = R; 863 PV.Left = false; 864 return true; 865 } 866 867 bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI, 868 BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV, 869 bool PreScan) { 870 using namespace PatternMatch; 871 872 // The basic pattern for R = P.Q is: 873 // for i = 0..31 874 // R = phi (0, R') 875 // if (P & (1 << i)) ; test-bit(P, i) 876 // R' = R ^ (Q << i) 877 // 878 // Similarly, the basic pattern for R = (P/Q).Q - P 879 // for i = 0..31 880 // R = phi(P, R') 881 // if (R & (1 << i)) 882 // R' = R ^ (Q << i) 883 884 // There exist idioms, where instead of Q being shifted left, P is shifted 885 // right. This produces a result that is shifted right by 32 bits (the 886 // non-shifted result is 64-bit). 887 // 888 // For R = P.Q, this would be: 889 // for i = 0..31 890 // R = phi (0, R') 891 // if ((P >> i) & 1) 892 // R' = (R >> 1) ^ Q ; R is cycled through the loop, so it must 893 // else ; be shifted by 1, not i. 894 // R' = R >> 1 895 // 896 // And for the inverse: 897 // for i = 0..31 898 // R = phi (P, R') 899 // if (R & 1) 900 // R' = (R >> 1) ^ Q 901 // else 902 // R' = R >> 1 903 904 // The left-shifting idioms share the same pattern: 905 // select (X & (1 << i)) ? R ^ (Q << i) : R 906 // Similarly for right-shifting idioms: 907 // select (X & 1) ? (R >> 1) ^ Q 908 909 if (matchLeftShift(SelI, CIV, PV)) { 910 // If this is a pre-scan, getting this far is sufficient. 911 if (PreScan) 912 return true; 913 914 // Need to make sure that the SelI goes back into R. 915 auto *RPhi = dyn_cast<PHINode>(PV.R); 916 if (!RPhi) 917 return false; 918 if (SelI != RPhi->getIncomingValueForBlock(LoopB)) 919 return false; 920 PV.Res = SelI; 921 922 // If X is loop invariant, it must be the input polynomial, and the 923 // idiom is the basic polynomial multiply. 924 if (CurLoop->isLoopInvariant(PV.X)) { 925 PV.P = PV.X; 926 PV.Inv = false; 927 } else { 928 // X is not loop invariant. 
If X == R, this is the inverse pmpy. 929 // Otherwise, check for an xor with an invariant value. If the 930 // variable argument to the xor is R, then this is still a valid 931 // inverse pmpy. 932 PV.Inv = true; 933 if (PV.X != PV.R) { 934 Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr; 935 if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2)))) 936 return false; 937 auto *I1 = dyn_cast<Instruction>(X1); 938 auto *I2 = dyn_cast<Instruction>(X2); 939 if (!I1 || I1->getParent() != LoopB) { 940 Var = X2; 941 Inv = X1; 942 } else if (!I2 || I2->getParent() != LoopB) { 943 Var = X1; 944 Inv = X2; 945 } else 946 return false; 947 if (Var != PV.R) 948 return false; 949 PV.M = Inv; 950 } 951 // The input polynomial P still needs to be determined. It will be 952 // the entry value of R. 953 Value *EntryP = RPhi->getIncomingValueForBlock(PrehB); 954 PV.P = EntryP; 955 } 956 957 return true; 958 } 959 960 if (matchRightShift(SelI, PV)) { 961 // If this is an inverse pattern, the Q polynomial must be known at 962 // compile time. 963 if (PV.Inv && !isa<ConstantInt>(PV.Q)) 964 return false; 965 if (PreScan) 966 return true; 967 // There is no exact matching of right-shift pmpy. 968 return false; 969 } 970 971 return false; 972 } 973 974 bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val, 975 IntegerType *DestTy) { 976 IntegerType *T = dyn_cast<IntegerType>(Val->getType()); 977 if (!T || T->getBitWidth() > DestTy->getBitWidth()) 978 return false; 979 if (T->getBitWidth() == DestTy->getBitWidth()) 980 return true; 981 // Non-instructions are promotable. The reason why an instruction may not 982 // be promotable is that it may produce a different result if its operands 983 // and the result are promoted, for example, it may produce more non-zero 984 // bits. While it would still be possible to represent the proper result 985 // in a wider type, it may require adding additional instructions (which 986 // we don't want to do). 
  Instruction *In = dyn_cast<Instruction>(Val);
  if (!In)
    return true;
  // The bitwidth of the source type is smaller than the destination.
  // Check if the individual operation can be promoted.
  switch (In->getOpcode()) {
    case Instruction::PHI:
    case Instruction::ZExt:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::LShr: // Shift right is ok.
    case Instruction::Select:
    case Instruction::Trunc:
      return true;
    case Instruction::ICmp:
      // Only equality and unsigned comparisons are insensitive to the extra
      // (zero) high bits introduced by the promotion.
      if (CmpInst *CI = cast<CmpInst>(In))
        return CI->isEquality() || CI->isUnsigned();
      // cast<> asserts on failure rather than returning null, so this branch
      // cannot be taken.
      llvm_unreachable("Cast failed unexpectedly");
    case Instruction::Add:
      // Add is safe only if it cannot wrap in the narrow type.
      return In->hasNoSignedWrap() && In->hasNoUnsignedWrap();
  }
  return false;
}

// Widen the type of In to DestTy in place, fixing up incoming PHI values
// from outside the loop, removing zexts that became no-ops, turning truncs
// into masking ands, and widening constant operands.
void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
      IntegerType *DestTy, BasicBlock *LoopB) {
  Type *OrigTy = In->getType();
  assert(!OrigTy->isVoidTy() && "Invalid instruction to promote");

  // Leave boolean values alone.
  if (!In->getType()->isIntegerTy(1))
    In->mutateType(DestTy);
  unsigned DestBW = DestTy->getBitWidth();

  // Handle PHIs.
  if (PHINode *P = dyn_cast<PHINode>(In)) {
    unsigned N = P->getNumIncomingValues();
    for (unsigned i = 0; i != N; ++i) {
      BasicBlock *InB = P->getIncomingBlock(i);
      if (InB == LoopB)
        continue;
      Value *InV = P->getIncomingValue(i);
      IntegerType *Ty = cast<IntegerType>(InV->getType());
      // Do not promote values in PHI nodes of type i1.
      if (Ty != P->getType()) {
        // If the value type does not match the PHI type, the PHI type
        // must have been promoted.
        assert(Ty->getBitWidth() < DestBW);
        // Zero-extend the incoming value in its defining block.
        InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy);
        P->setIncomingValue(i, InV);
      }
    }
  } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) {
    // The zext's operand may have been promoted to the same type, making
    // the zext a no-op; forward the operand and delete the zext.
    Value *Op = Z->getOperand(0);
    if (Op->getType() == Z->getType())
      Z->replaceAllUsesWith(Op);
    Z->eraseFromParent();
    return;
  }
  if (TruncInst *T = dyn_cast<TruncInst>(In)) {
    // Replace the trunc with an and-mask keeping the original low bits.
    // NOTE(review): the 1u << bitwidth expression assumes the truncated
    // type is narrower than 32 bits — consistent with DestTy being i32 and
    // OrigTy being strictly narrower here; confirm if DestTy ever changes.
    IntegerType *TruncTy = cast<IntegerType>(OrigTy);
    Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1);
    Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask);
    T->replaceAllUsesWith(And);
    T->eraseFromParent();
    return;
  }

  // Promote immediates.
  for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
    if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
      if (CI->getBitWidth() < DestBW)
        In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
  }
}

// Promote every value computed in LoopB (and the exit PHIs fed by it) to the
// pmpy type (i32). Returns false if any value cannot be promoted safely.
bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB,
      BasicBlock *ExitB) {
  assert(LoopB);
  // Skip loops where the exit block has more than one predecessor. The values
  // coming from the loop block will be promoted to another type, and so the
  // values coming into the exit block from other predecessors would also have
  // to be promoted.
  if (!ExitB || (ExitB->getSinglePredecessor() != LoopB))
    return false;
  IntegerType *DestTy = getPmpyType();
  // Check if the exit values have types that are no wider than the type
  // that we want to promote to.
  unsigned DestBW = DestTy->getBitWidth();
  for (PHINode &P : ExitB->phis()) {
    if (P.getNumIncomingValues() != 1)
      return false;
    assert(P.getIncomingBlock(0) == LoopB);
    IntegerType *T = dyn_cast<IntegerType>(P.getType());
    if (!T || T->getBitWidth() > DestBW)
      return false;
  }

  // Check all instructions in the loop.
  for (Instruction &In : *LoopB)
    if (!In.isTerminator() && !isPromotableTo(&In, DestTy))
      return false;

  // Perform the promotion.
  // Snapshot the instruction list first: promoteTo may erase instructions
  // (zext/trunc), which would invalidate a live block iterator.
  std::vector<Instruction*> LoopIns;
  std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns),
                 [](Instruction &In) { return &In; });
  for (Instruction *In : LoopIns)
    if (!In->isTerminator())
      promoteTo(In, DestTy, LoopB);

  // Fix up the PHI nodes in the exit block.
  Instruction *EndI = ExitB->getFirstNonPHI();
  BasicBlock::iterator End = EndI ? EndI->getIterator() : ExitB->end();
  for (auto I = ExitB->begin(); I != End; ++I) {
    PHINode *P = dyn_cast<PHINode>(I);
    if (!P)
      break;
    Type *Ty0 = P->getIncomingValue(0)->getType();
    Type *PTy = P->getType();
    if (PTy != Ty0) {
      // The PHI still has its original (narrow) type, but its incoming
      // value was promoted; interpose a trunc back to the original type.
      // The mutateType shuffle below is needed because both CreateTrunc
      // and replaceAllUsesWith require consistent types at each step.
      assert(Ty0 == DestTy);
      // In order to create the trunc, P must have the promoted type.
      P->mutateType(Ty0);
      Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy);
      // In order for the RAUW to work, the types of P and T must match.
      P->mutateType(PTy);
      P->replaceAllUsesWith(T);
      // Final update of the P's type.
      P->mutateType(Ty0);
      // RAUW rewrote the trunc's own operand to T; point it back at P.
      cast<Instruction>(T)->setOperand(0, P);
    }
  }

  return true;
}

// Search for a use-def cycle leading from Out, through its users in the
// same block, back to In. On success the members of the cycle (excluding
// Out itself) are left in Cycle. Cycles containing more than one PHI are
// rejected, since they would span multiple loop iterations.
bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In,
      ValueSeq &Cycle) {
  // Out = ..., In, ...
  if (Out == In)
    return true;

  auto *BB = cast<Instruction>(Out)->getParent();
  bool HadPhi = false;

  for (auto *U : Out->users()) {
    auto *I = dyn_cast<Instruction>(&*U);
    // Only follow users that are instructions in the same block.
    if (I == nullptr || I->getParent() != BB)
      continue;
    // Make sure that there are no multi-iteration cycles, e.g.
    //   p1 = phi(p2)
    //   p2 = phi(p1)
    // The cycle p1->p2->p1 would span two loop iterations.
    // Check that there is only one phi in the cycle.
    bool IsPhi = isa<PHINode>(I);
    if (IsPhi && HadPhi)
      return false;
    HadPhi |= IsPhi;
    // insert() returning false means I is already on the path: a sub-cycle
    // that does not reach In — reject.
    if (!Cycle.insert(I))
      return false;
    if (findCycle(I, In, Cycle))
      break;
    // This user did not lead back to In; backtrack.
    Cycle.remove(I);
  }
  return !Cycle.empty();
}

// Partition the values of a shift cycle into "early" (between the phi and
// the divider instruction DivI, i.e. still holding the previous iteration's
// value) and "late" (the rest). The cycle is traversed in order starting
// from whichever of DivI / the phi appears first.
void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI,
      ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) {
  // All the values in the cycle that are between the phi node and the
  // divider instruction will be classified as "early", all other values
  // will be "late".

  bool IsE = true;
  unsigned I, N = Cycle.size();
  for (I = 0; I < N; ++I) {
    Value *V = Cycle[I];
    if (DivI == V)
      IsE = false;
    else if (!isa<PHINode>(V))
      continue;
    // Stop if found either.
    break;
  }
  // "I" is the index of either DivI or the phi node, whichever was first.
  // "E" is "false" or "true" respectively.
  ValueSeq &First = !IsE ? Early : Late;
  for (unsigned J = 0; J < I; ++J)
    First.insert(Cycle[J]);

  ValueSeq &Second = IsE ? Early : Late;
  Second.insert(Cycle[I]);
  for (++I; I < N; ++I) {
    Value *V = Cycle[I];
    // The second boundary (the other of DivI / phi) flips back to First.
    if (DivI == V || isa<PHINode>(V))
      break;
    Second.insert(V);
  }

  for (; I < N; ++I)
    First.insert(Cycle[I]);
}

// Classify UseI as early or late based on its operands' classification.
// Returns false when the operands are an unresolvable early/late mixture.
bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI,
      ValueSeq &Early, ValueSeq &Late) {
  // Select is an exception, since the condition value does not have to be
  // classified in the same way as the true/false values. The true/false
  // values do have to be both early or both late.
  if (UseI->getOpcode() == Instruction::Select) {
    Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2);
    if (Early.count(TV) || Early.count(FV)) {
      if (Late.count(TV) || Late.count(FV))
        return false;
      Early.insert(UseI);
    } else if (Late.count(TV) || Late.count(FV)) {
      if (Early.count(TV) || Early.count(FV))
        return false;
      Late.insert(UseI);
    }
    return true;
  }

  // Not sure what would be the example of this, but the code below relies
  // on having at least one operand.
  if (UseI->getNumOperands() == 0)
    return true;

  // AE/AL: "all operands early" / "all operands late". An operand that is
  // classified as neither leaves both flags untouched.
  bool AE = true, AL = true;
  for (auto &I : UseI->operands()) {
    if (Early.count(&*I))
      AL = false;
    else if (Late.count(&*I))
      AE = false;
  }
  // If the operands appear "all early" and "all late" at the same time,
  // then it means that none of them are actually classified as either.
  // This is harmless.
  if (AE && AL)
    return true;
  // Conversely, if they are neither "all early" nor "all late", then
  // we have a mixture of early and late operands that is not a known
  // exception.
  if (!AE && !AL)
    return false;

  // Check that we have covered the two special cases.
  assert(AE != AL);

  if (AE)
    Early.insert(UseI);
  else
    Late.insert(UseI);
  return true;
}

// Return true if the operation performed by I distributes over a left
// shift of its operands, i.e. shifting the inputs and shifting the output
// give the same result (bitwise ops, shifts, select/icmp/phi).
bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) {
  switch (I->getOpcode()) {
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::LShr:
  case Instruction::Shl:
  case Instruction::Select:
  case Instruction::ICmp:
  case Instruction::PHI:
    break;
  default:
    return false;
  }
  return true;
}

// Return true if the top IterCount bits of V are known to be zero,
// according to known-bits analysis.
bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V,
      unsigned IterCount) {
  auto *T = dyn_cast<IntegerType>(V->getType());
  if (!T)
    return false;

  KnownBits Known(T->getBitWidth());
  computeKnownBits(V, Known, DL);
  return Known.countMinLeadingZeros() >= IterCount;
}

// Return true if V, given inputs whose top IterCount bits are zero,
// produces a result whose top IterCount bits are also zero.
bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V,
      unsigned IterCount) {
  // Assume that all inputs to the value have the high bits zero.
  // Check if the value itself preserves the zeros in the high bits.
  if (auto *C = dyn_cast<ConstantInt>(V))
    return C->getValue().countl_zero() >= IterCount;

  if (auto *I = dyn_cast<Instruction>(V)) {
    // These operations cannot move zeros out of the high bits.
    // Note that Shl and Add are deliberately absent: they can carry or
    // shift bits into the high part.
    switch (I->getOpcode()) {
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::LShr:
    case Instruction::Select:
    case Instruction::ICmp:
    case Instruction::PHI:
    case Instruction::ZExt:
      return true;
    }
  }

  return false;
}

// Return true if operand Op of I is a value being shifted (as opposed to
// the shift amount). For shl/lshr only operand 0 is shifted; for all other
// instructions every operand counts as shifted.
bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) {
  unsigned Opc = I->getOpcode();
  if (Opc == Instruction::Shl || Opc == Instruction::LShr)
    return Op != I->getOperand(1);
  return true;
}

// Rewrite a loop that shifts its working values right by 1 each iteration
// into an equivalent loop that shifts loop inputs left instead, so the
// per-iteration lshr disappears. Exit values are compensated with a single
// lshr by IterCount after the loop. Returns false if any precondition
// (commuting ops, zero high bits, classifiable users) fails.
bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB,
      BasicBlock *ExitB, unsigned IterCount) {
  Value *CIV = getCountIV(LoopB);
  if (CIV == nullptr)
    return false;
  auto *CIVTy = dyn_cast<IntegerType>(CIV->getType());
  if (CIVTy == nullptr)
    return false;

  ValueSeq RShifts;
  ValueSeq Early, Late, Cycled;

  // Find all value cycles that contain logical right shifts by 1.
  for (Instruction &I : *LoopB) {
    using namespace PatternMatch;

    Value *V = nullptr;
    if (!match(&I, m_LShr(m_Value(V), m_One())))
      continue;
    ValueSeq C;
    if (!findCycle(&I, V, C))
      continue;

    // Found a cycle.
    C.insert(&I);
    classifyCycle(&I, C, Early, Late);
    Cycled.insert(C.begin(), C.end());
    RShifts.insert(&I);
  }

  // Find the set of all values affected by the shift cycles, i.e. all
  // cycled values, and (recursively) all their users.
  // Users grows while being iterated (worklist pattern) — index-based loop
  // with size() re-evaluated each pass is intentional.
  ValueSeq Users(Cycled.begin(), Cycled.end());
  for (unsigned i = 0; i < Users.size(); ++i) {
    Value *V = Users[i];
    if (!isa<IntegerType>(V->getType()))
      return false;
    auto *R = cast<Instruction>(V);
    // If the instruction does not commute with shifts, the loop cannot
    // be unshifted.
    if (!commutesWithShift(R))
      return false;
    for (User *U : R->users()) {
      auto *T = cast<Instruction>(U);
      // Skip users from outside of the loop. They will be handled later.
      // Also, skip the right-shifts and phi nodes, since they mix early
      // and late values.
      if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T))
        continue;

      Users.insert(T);
      if (!classifyInst(T, Early, Late))
        return false;
    }
  }

  if (Users.empty())
    return false;

  // Verify that high bits remain zero.
  // Internal: values computed in the loop; Inputs: values defined outside.
  // Internal is also a growing worklist.
  ValueSeq Internal(Users.begin(), Users.end());
  ValueSeq Inputs;
  for (unsigned i = 0; i < Internal.size(); ++i) {
    auto *R = dyn_cast<Instruction>(Internal[i]);
    if (!R)
      continue;
    for (Value *Op : R->operands()) {
      auto *T = dyn_cast<Instruction>(Op);
      if (T && T->getParent() != LoopB)
        Inputs.insert(Op);
      else
        Internal.insert(Op);
    }
  }
  for (Value *V : Inputs)
    if (!highBitsAreZero(V, IterCount))
      return false;
  for (Value *V : Internal)
    if (!keepsHighBitsZero(V, IterCount))
      return false;

  // Finally, the work can be done. Unshift each user.
  IRBuilder<> IRB(LoopB);
  // Memoize the per-value left shift we synthesize, so each loop input is
  // shifted only once.
  std::map<Value*,Value*> ShiftMap;

  using CastMapType = std::map<std::pair<Value *, Type *>, Value *>;

  CastMapType CastMap;

  // Create (or reuse) a zero-extending cast of V to Ty, memoized per
  // (value, type) pair.
  auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V,
                    IntegerType *Ty) -> Value* {
    auto H = CM.find(std::make_pair(V, Ty));
    if (H != CM.end())
      return H->second;
    Value *CV = IRB.CreateIntCast(V, Ty, false);
    CM.insert(std::make_pair(std::make_pair(V, Ty), CV));
    return CV;
  };

  for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) {
    using namespace PatternMatch;

    if (isa<PHINode>(I) || !Users.count(&*I))
      continue;

    // Match lshr x, 1.
    // The per-iteration right shift becomes a no-op: forward x directly.
    Value *V = nullptr;
    if (match(&*I, m_LShr(m_Value(V), m_One()))) {
      replaceAllUsesOfWithIn(&*I, V, LoopB);
      continue;
    }
    // For each non-cycled operand, replace it with the corresponding
    // value shifted left.
    for (auto &J : I->operands()) {
      Value *Op = J.get();
      if (!isOperandShifted(&*I, Op))
        continue;
      if (Users.count(Op))
        continue;
      // Skip shifting zeros.
      if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
        continue;
      // Check if we have already generated a shift for this value.
      auto F = ShiftMap.find(Op);
      Value *W = (F != ShiftMap.end()) ? F->second : nullptr;
      if (W == nullptr) {
        IRB.SetInsertPoint(&*I);
        // First, the shift amount will be CIV or CIV+1, depending on
        // whether the value is early or late. Instead of creating CIV+1,
        // do a single shift of the value.
        Value *ShAmt = CIV, *ShVal = Op;
        auto *VTy = cast<IntegerType>(ShVal->getType());
        auto *ATy = cast<IntegerType>(ShAmt->getType());
        if (Late.count(&*I))
          ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1));
        // Second, the types of the shifted value and the shift amount
        // must match.
        if (VTy != ATy) {
          if (VTy->getBitWidth() < ATy->getBitWidth())
            ShVal = upcast(CastMap, IRB, ShVal, ATy);
          else
            ShAmt = upcast(CastMap, IRB, ShAmt, VTy);
        }
        // Ready to generate the shift and memoize it.
        W = IRB.CreateShl(ShVal, ShAmt);
        ShiftMap.insert(std::make_pair(Op, W));
      }
      I->replaceUsesOfWith(Op, W);
    }
  }

  // Update the users outside of the loop to account for having left
  // shifts. They would normally be shifted right in the loop, so shift
  // them right after the loop exit.
  // Take advantage of the loop-closed SSA form, which has all the post-
  // loop values in phi nodes.
  IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt());
  for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) {
    if (!isa<PHINode>(P))
      break;
    auto *PN = cast<PHINode>(P);
    Value *U = PN->getIncomingValueForBlock(LoopB);
    if (!Users.count(U))
      continue;
    Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount));
    PN->replaceAllUsesWith(S);
    // The above RAUW will create
    //   S = lshr S, IterCount
    // so we need to fix it back into
    //   S = lshr PN, IterCount
    cast<User>(S)->replaceUsesOfWith(S, PN);
  }

  return true;
}

// Simplify each instruction in the loop body and then delete anything that
// became trivially dead.
void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) {
  for (auto &I : *LoopB)
    if (Value *SV = simplifyInstruction(&I, {DL, &TLI, &DT}))
      I.replaceAllUsesWith(SV);

  // early_inc_range: deletion invalidates the current iterator, so advance
  // before erasing.
  for (Instruction &I : llvm::make_early_inc_range(*LoopB))
    RecursivelyDeleteTriviallyDeadInstructions(&I, &TLI);
}

// Compute the multiplicative inverse of the polynomial QP in Z2[x],
// modulo x^32. QP must have a non-zero constant term (bit 0 set).
unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) {
  // Arrays of coefficients of Q and the inverse, C.
  // Q[i] = coefficient at x^i.
  std::array<char,32> Q, C;

  // Unpack the bits of QP into the coefficient array Q.
  for (unsigned i = 0; i < 32; ++i) {
    Q[i] = QP & 1;
    QP >>= 1;
  }
  assert(Q[0] == 1);

  // Find C, such that
  // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1
  //
  // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the
  // operations * and + are & and ^ respectively.
  //
  // Find C[i] recursively, by comparing i-th coefficient in the product
  // with 0 (or 1 for i=0).
  //
  // C[0] = 1, since C[0] = Q[0], and Q[0] = 1.
  C[0] = 1;
  for (unsigned i = 1; i < 32; ++i) {
    // Solve for C[i] in:
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0
    // This is equivalent to
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0
    // which is
    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i]
    unsigned T = 0;
    for (unsigned j = 0; j < i; ++j)
      T = T ^ (C[j] & Q[i-j]);
    C[i] = T;
  }

  // Repack the coefficient array C into an integer.
  // NOTE(review): "(1 << i)" with i == 31 shifts into the sign bit of a
  // signed int; "1u << i" would avoid the (implementation-tolerated) UB.
  unsigned QV = 0;
  for (unsigned i = 0; i < 32; ++i)
    if (C[i])
      QV |= (1 << i);

  return QV;
}

// Emit the replacement code for the recognized pmpy idiom at At, using the
// Hexagon M4_pmpyw intrinsic. Handles the optional pre/post xor with PV.M
// and the "inverse" variant (divide becomes multiply by the polynomial
// inverse of Q). Returns the final result value.
Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
      ParsedValues &PV) {
  IRBuilder<> B(&*At);
  Module *M = At->getParent()->getParent()->getParent();
  Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw);

  Value *P = PV.P, *Q = PV.Q, *P0 = P;
  unsigned IC = PV.IterCount;

  if (PV.M != nullptr)
    P0 = P = B.CreateXor(P, PV.M);

  // Create a bit mask to clear the high bits beyond IterCount.
  auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC));

  if (PV.IterCount != 32)
    P = B.CreateAnd(P, BMI);

  if (PV.Inv) {
    auto *QI = dyn_cast<ConstantInt>(PV.Q);
    assert(QI && QI->getBitWidth() <= 32);

    // Again, clearing bits beyond IterCount.
    // NOTE(review): if PV.IterCount == 32 this shift is UB ("1 << 32");
    // the IC != 32 guards nearby suggest IterCount < 32 is expected on
    // this path — confirm.
    unsigned M = (1 << PV.IterCount) - 1;
    unsigned Tmp = (QI->getZExtValue() | 1) & M;
    unsigned QV = getInverseMxN(Tmp) & M;
    auto *QVI = ConstantInt::get(QI->getType(), QV);
    P = B.CreateCall(PMF, {P, QVI});
    P = B.CreateTrunc(P, QI->getType());
    if (IC != 32)
      P = B.CreateAnd(P, BMI);
  }

  Value *R = B.CreateCall(PMF, {P, Q});

  // Undo the initial xor with M on the (widened) result.
  if (PV.M != nullptr)
    R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false));

  return R;
}

// Conservatively determine whether the sign bit of V is known to be zero,
// by structural recursion over constants and a few bitwise operations.
static bool hasZeroSignBit(const Value *V) {
  if (const auto *CI = dyn_cast<const ConstantInt>(V))
    return CI->getValue().isNonNegative();
  const Instruction *I = dyn_cast<const Instruction>(V);
  if (!I)
    return false;
  switch (I->getOpcode()) {
  case Instruction::LShr:
    // Any lshr by a non-zero constant clears the sign bit.
    if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1)))
      return SI->getZExtValue() > 0;
    return false;
  case Instruction::Or:
  case Instruction::Xor:
    // or/xor: both inputs must have a zero sign bit.
    return hasZeroSignBit(I->getOperand(0)) &&
           hasZeroSignBit(I->getOperand(1));
  case Instruction::And:
    // and: one zero sign bit suffices.
    return hasZeroSignBit(I->getOperand(0)) ||
           hasZeroSignBit(I->getOperand(1));
  }
  return false;
}

// Register the rewrite rules applied before scanning for the pmpy pattern.
// Each rule either returns a replacement value (built with an unattached
// IRBuilder) or nullptr when it does not apply.
void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
  S.addRule("sink-zext",
    // Sink zext past bitwise operations.
    // zext(op(x, y)) -> op(zext(x), zext(y)) for bitwise op.
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::ZExt)
        return nullptr;
      Instruction *T = dyn_cast<Instruction>(I->getOperand(0));
      if (!T)
        return nullptr;
      switch (T->getOpcode()) {
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor:
        break;
      default:
        return nullptr;
      }
      IRBuilder<> B(Ctx);
      return B.CreateBinOp(cast<BinaryOperator>(T)->getOpcode(),
                           B.CreateZExt(T->getOperand(0), I->getType()),
                           B.CreateZExt(T->getOperand(1), I->getType()));
    });
  S.addRule("xor/and -> and/xor",
    // (xor (and x a) (and y a)) -> (and (xor x y) a)
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::Xor)
        return nullptr;
      Instruction *And0 = dyn_cast<Instruction>(I->getOperand(0));
      Instruction *And1 = dyn_cast<Instruction>(I->getOperand(1));
      if (!And0 || !And1)
        return nullptr;
      if (And0->getOpcode() != Instruction::And ||
          And1->getOpcode() != Instruction::And)
        return nullptr;
      // The masks must be the same value (pointer identity suffices,
      // since constants are uniqued).
      if (And0->getOperand(1) != And1->getOperand(1))
        return nullptr;
      IRBuilder<> B(Ctx);
      return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1->getOperand(0)),
                         And0->getOperand(1));
    });
  S.addRule("sink binop into select",
    // (Op (select c x y) z) -> (select c (Op x z) (Op y z))
    // (Op x (select c y z)) -> (select c (Op x y) (Op x z))
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      BinaryOperator *BO = dyn_cast<BinaryOperator>(I);
      if (!BO)
        return nullptr;
      Instruction::BinaryOps Op = BO->getOpcode();
      if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(0))) {
        IRBuilder<> B(Ctx);
        Value *X = Sel->getTrueValue(), *Y = Sel->getFalseValue();
        Value *Z = BO->getOperand(1);
        return B.CreateSelect(Sel->getCondition(),
                              B.CreateBinOp(Op, X, Z),
                              B.CreateBinOp(Op, Y, Z));
      }
      if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(1))) {
        IRBuilder<> B(Ctx);
        Value *X = BO->getOperand(0);
        Value *Y = Sel->getTrueValue(), *Z = Sel->getFalseValue();
        return B.CreateSelect(Sel->getCondition(),
                              B.CreateBinOp(Op, X, Y),
                              B.CreateBinOp(Op, X, Z));
      }
      return nullptr;
    });
  S.addRule("fold select-select",
    // (select c (select c x y) z) -> (select c x z)
    // (select c x (select c y z)) -> (select c x z)
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      SelectInst *Sel = dyn_cast<SelectInst>(I);
      if (!Sel)
        return nullptr;
      IRBuilder<> B(Ctx);
      Value *C = Sel->getCondition();
      if (SelectInst *Sel0 = dyn_cast<SelectInst>(Sel->getTrueValue())) {
        if (Sel0->getCondition() == C)
          return B.CreateSelect(C, Sel0->getTrueValue(), Sel->getFalseValue());
      }
      if (SelectInst *Sel1 = dyn_cast<SelectInst>(Sel->getFalseValue())) {
        if (Sel1->getCondition() == C)
          return B.CreateSelect(C, Sel->getTrueValue(), Sel1->getFalseValue());
      }
      return nullptr;
    });
  S.addRule("or-signbit -> xor-signbit",
    // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0)
    // Valid because the sign bit of the left operand is known zero, so
    // or and xor of the sign-mask constant are equivalent.
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::Or)
        return nullptr;
      ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
      if (!Msb || !Msb->getValue().isSignMask())
        return nullptr;
      if (!hasZeroSignBit(I->getOperand(0)))
        return nullptr;
      return IRBuilder<>(Ctx).CreateXor(I->getOperand(0), Msb);
    });
  S.addRule("sink lshr into binop",
    // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c))
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::LShr)
        return nullptr;
      BinaryOperator *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0));
      if (!BitOp)
        return nullptr;
      switch (BitOp->getOpcode()) {
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor:
        break;
      default:
        return nullptr;
      }
      IRBuilder<> B(Ctx);
      Value *S = I->getOperand(1);
      return B.CreateBinOp(BitOp->getOpcode(),
                           B.CreateLShr(BitOp->getOperand(0), S),
                           B.CreateLShr(BitOp->getOperand(1), S));
    });
  S.addRule("expose bitop-const",
    // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b))
    // Regroups so two constants can fold together.
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      auto IsBitOp = [](unsigned Op) -> bool {
        switch (Op) {
        case Instruction::And:
        case Instruction::Or:
        case Instruction::Xor:
          return true;
        }
        return false;
      };
      BinaryOperator *BitOp1 = dyn_cast<BinaryOperator>(I);
      if (!BitOp1 || !IsBitOp(BitOp1->getOpcode()))
        return nullptr;
      BinaryOperator *BitOp2 = dyn_cast<BinaryOperator>(BitOp1->getOperand(0));
      if (!BitOp2 || !IsBitOp(BitOp2->getOpcode()))
        return nullptr;
      ConstantInt *CA = dyn_cast<ConstantInt>(BitOp2->getOperand(1));
      ConstantInt *CB = dyn_cast<ConstantInt>(BitOp1->getOperand(1));
      if (!CA || !CB)
        return nullptr;
      IRBuilder<> B(Ctx);
      Value *X = BitOp2->getOperand(0);
      return B.CreateBinOp(BitOp2->getOpcode(), X,
                           B.CreateBinOp(BitOp1->getOpcode(), CA, CB));
    });
}

// Register the rewrite rules applied after type promotion.
void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) {
  S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a",
    [](Instruction *I, LLVMContext &Ctx) -> Value* {
      if (I->getOpcode() != Instruction::And)
        return nullptr;
      Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0));
      ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1));
      if (!Xor || !C0)
        return nullptr;
      if (Xor->getOpcode() != Instruction::Xor)
        return nullptr;
      Instruction *And0 = dyn_cast<Instruction>(Xor->getOperand(0));
      Instruction *And1 = dyn_cast<Instruction>(Xor->getOperand(1));
      // Pick the first non-null and.
      // NOTE(review): if neither xor operand is an And instruction, And0
      // can still be null (or a non-And) after the swap, and the
      // getOperand(1) below would dereference it — confirm the matcher's
      // context rules this out, or add an explicit check.
      if (!And0 || And0->getOpcode() != Instruction::And)
        std::swap(And0, And1);
      ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1));
      if (!C1)
        return nullptr;
      uint32_t V0 = C0->getZExtValue();
      uint32_t V1 = C1->getZExtValue();
      // b must be a subset of a (b == b&a) for the inner mask to be
      // redundant under the outer one.
      if (V0 != (V0 & V1))
        return nullptr;
      IRBuilder<> B(Ctx);
      return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1), C0);
    });
}

// Top-level driver: recognize a polynomial-multiply (pmpy) loop and replace
// it with a call to the Hexagon pmpy intrinsic. Returns true if the loop
// was transformed.
bool PolynomialMultiplyRecognize::recognize() {
  LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n"
                    << *CurLoop << '\n');
  // Restrictions:
  // - The loop must consist of a single block.
  // - The iteration count must be known at compile-time.
  // - The loop must have an induction variable starting from 0, and
  //   incremented in each iteration of the loop.
  BasicBlock *LoopB = CurLoop->getHeader();
  LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB);

  // Single-block loop: header and latch must coincide.
  if (LoopB != CurLoop->getLoopLatch())
    return false;
  BasicBlock *ExitB = CurLoop->getExitBlock();
  if (ExitB == nullptr)
    return false;
  BasicBlock *EntryB = CurLoop->getLoopPreheader();
  if (EntryB == nullptr)
    return false;

  unsigned IterCount = 0;
  const SCEV *CT = SE.getBackedgeTakenCount(CurLoop);
  if (isa<SCEVCouldNotCompute>(CT))
    return false;
  // Trip count = backedge-taken count + 1.
  if (auto *CV = dyn_cast<SCEVConstant>(CT))
    IterCount = CV->getValue()->getZExtValue() + 1;

  Value *CIV = getCountIV(LoopB);
  ParsedValues PV;
  Simplifier PreSimp;
  PV.IterCount = IterCount;
  LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount
                    << '\n');

  setupPreSimplifier(PreSimp);

  // Perform a preliminary scan of select instructions to see if any of them
  // looks like a generator of the polynomial multiply steps. Assume that a
  // loop can only contain a single transformable operation, so stop the
  // traversal after the first reasonable candidate was found.
  // XXX: Currently this approach can modify the loop before being 100% sure
  // that the transformation can be carried out.
  bool FoundPreScan = false;
  // A candidate select must feed a PHI in the loop (i.e. be part of a
  // recurrence).
  auto FeedsPHI = [LoopB](const Value *V) -> bool {
    for (const Value *U : V->users()) {
      if (const auto *P = dyn_cast<const PHINode>(U))
        if (P->getParent() == LoopB)
          return true;
    }
    return false;
  };
  for (Instruction &In : *LoopB) {
    SelectInst *SI = dyn_cast<SelectInst>(&In);
    if (!SI || !FeedsPHI(SI))
      continue;

    Simplifier::Context C(SI);
    Value *T = PreSimp.simplify(C);
    SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : SI;
    LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n');
    if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) {
      FoundPreScan = true;
      // Commit the simplified form into the loop if it differs.
      if (SelI != SI) {
        Value *NewSel = C.materialize(LoopB, SI->getIterator());
        SI->replaceAllUsesWith(NewSel);
        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
      }
      break;
    }
  }

  if (!FoundPreScan) {
    LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n");
    return false;
  }

  if (!PV.Left) {
    // The right shift version actually only returns the higher bits of
    // the result (each iteration discards the LSB). If we want to convert it
    // to a left-shifting loop, the working data type must be at least as
    // wide as the target's pmpy instruction.
    if (!promoteTypes(LoopB, ExitB))
      return false;
    // Run post-promotion simplifications.
    Simplifier PostSimp;
    setupPostSimplifier(PostSimp);
    // Same scan as the pre-scan above, but only for the first select
    // feeding a PHI (note the unconditional break at the end).
    for (Instruction &In : *LoopB) {
      SelectInst *SI = dyn_cast<SelectInst>(&In);
      if (!SI || !FeedsPHI(SI))
        continue;
      Simplifier::Context C(SI);
      Value *T = PostSimp.simplify(C);
      SelectInst *SelI = dyn_cast_or_null<SelectInst>(T);
      if (SelI != SI) {
        Value *NewSel = C.materialize(LoopB, SI->getIterator());
        SI->replaceAllUsesWith(NewSel);
        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
      }
      break;
    }

    if (!convertShiftsToLeft(LoopB, ExitB, IterCount))
      return false;
    cleanupLoopBody(LoopB);
  }

  // Scan the loop again, find the generating select instruction.
  bool FoundScan = false;
  for (Instruction &In : *LoopB) {
    SelectInst *SelI = dyn_cast<SelectInst>(&In);
    if (!SelI)
      continue;
    LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n');
    FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false);
    if (FoundScan)
      break;
  }
  assert(FoundScan);

  LLVM_DEBUG({
    StringRef PP = (PV.M ? "(P+M)" : "P");
    if (!PV.Inv)
      dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n";
    else
      dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + "
             << PP << "\n";
    dbgs() << " Res:" << *PV.Res << "\n P:" << *PV.P << "\n";
    if (PV.M)
      dbgs() << " M:" << *PV.M << "\n";
    dbgs() << " Q:" << *PV.Q << "\n";
    dbgs() << " Iteration count:" << PV.IterCount << "\n";
  });

  // Emit the intrinsic-based replacement in the preheader and delete the
  // loop's result computation.
  BasicBlock::iterator At(EntryB->getTerminator());
  Value *PM = generate(At, PV);
  if (PM == nullptr)
    return false;

  if (PM->getType() != PV.Res->getType())
    PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false);

  PV.Res->replaceAllUsesWith(PM);
  PV.Res->eraseFromParent();
  return true;
}

// Return the constant stride of the affine AddRec S, or 0 if the stride is
// not a compile-time constant.
int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
    return SC->getAPInt().getSExtValue();
  return 0;
}

// Check whether SI is a store that can participate in the memcpy/memmove
// idiom: a simple (or allowed-volatile) store of a strided load, where the
// store stride equals the store size and matches the load stride.
bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) {
  // Allow volatile stores if HexagonVolatileMemcpy is enabled.
  if (!(SI->isVolatile() && HexagonVolatileMemcpy) && !SI->isSimple())
    return false;

  Value *StoredVal = SI->getValueOperand();
  Value *StorePtr = SI->getPointerOperand();

  // Reject stores that are so large that they overflow an unsigned.
  // Also require a byte-sized value (size divisible by 8 bits).
  uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType());
  if ((SizeInBits & 7) || (SizeInBits >> 32) != 0)
    return false;

  // See if the pointer expression is an AddRec like {base,+,1} on the current
  // loop, which indicates a strided store. If we have something else, it's a
  // random store we can't handle.
  auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
    return false;

  // Check to see if the stride matches the size of the store. If so, then we
  // know that every byte is touched in the loop.
  int Stride = getSCEVStride(StoreEv);
  if (Stride == 0)
    return false;
  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
  if (StoreSize != unsigned(std::abs(Stride)))
    return false;

  // The store must be feeding a non-volatile load.
  LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
  if (!LI || !LI->isSimple())
    return false;

  // See if the pointer expression is an AddRec like {base,+,1} on the current
  // loop, which indicates a strided load. If we have something else, it's a
  // random load we can't handle.
  Value *LoadPtr = LI->getPointerOperand();
  auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr));
  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
    return false;

  // The store and load must share the same stride.
  if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
    return false;

  // Success. This store can be converted into a memcpy.
  return true;
}

/// mayLoopAccessLocation - Return true if the specified loop might access the
/// specified pointer location, which is a loop-strided access. The 'Access'
/// argument specifies what the verboten forms of access are (read or write).
static bool
mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
                      const SCEV *BECount, unsigned StoreSize,
                      AliasAnalysis &AA,
                      SmallPtrSetImpl<Instruction *> &Ignored) {
  // Get the location that may be stored across the loop. Since the access
  // is strided positively through memory, we say that the modified location
  // starts at the pointer and has infinite size.
  LocationSize AccessSize = LocationSize::afterPointer();

  // If the loop iterates a fixed number of times, we can refine the access
  // size to be exactly the size of the memset, which is (BECount+1)*StoreSize
  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
    AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
                                       StoreSize);

  // TODO: For this to be really effective, we have to dive into the pointer
  // operand in the store. Store to &A[i] of 100 will always return may alias
  // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
  // which will then no-alias a store to &A[100].
  MemoryLocation StoreLoc(Ptr, AccessSize);

  // Any instruction in the loop (other than the ignored ones) that mods or
  // refs the location in a verboten way makes the transformation unsafe.
  for (auto *B : L->blocks())
    for (auto &I : *B)
      if (Ignored.count(&I) == 0 &&
          isModOrRefSet(AA.getModRefInfo(&I, StoreLoc) & Access))
        return true;

  return false;
}

// Collect all stores in BB that are legal candidates for the
// memcpy/memmove idiom (see isLegalStore).
void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB,
      SmallVectorImpl<StoreInst*> &Stores) {
  Stores.clear();
  for (Instruction &I : *BB)
    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
      if (isLegalStore(CurLoop, SI))
        Stores.push_back(SI);
}

// Try to replace the strided load/store loop around SI with a
// memcpy/memmove call (or the Hexagon-specific volatile-destination
// memcpy). Returns true if the loop was rewritten.
bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop,
      StoreInst *SI, const SCEV *BECount) {
  // NOTE(review): the two concatenated literals are missing a separating
  // space, so the message reads "...memcpyto volatile...". Fixing it would
  // change a runtime string, so it is left as-is here.
  assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) &&
         "Expected only non-volatile stores, or Hexagon-specific memcpy"
         "to volatile destination.");

  Value *StorePtr = SI->getPointerOperand();
  auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
  unsigned Stride = getSCEVStride(StoreEv);
  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
  if (Stride != StoreSize)
    return false;

  // See if the pointer expression is an AddRec like {base,+,1} on the current
  // loop, which indicates a strided load. If we have something else, it's a
  // random load we can't handle.
  // (Both casts are safe: isLegalStore already validated this shape.)
  auto *LI = cast<LoadInst>(SI->getValueOperand());
  auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));

  // The trip count of the loop and the base pointer of the addrec SCEV is
  // guaranteed to be loop invariant, which means that it should dominate the
  // header. This allows us to insert code for it in the preheader.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  Instruction *ExpPt = Preheader->getTerminator();
  IRBuilder<> Builder(ExpPt);
  SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom");

  Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace());

  // Okay, we have a strided store "p[i]" of a loaded value. We can turn
  // this into a memcpy/memmove in the loop preheader now if we want. However,
  // this would be unsafe to do if there is anything else in the loop that may
  // read or write the memory region we're storing to. For memcpy, this
  // includes the load that feeds the stores. Check for an alias by generating
  // the base address and checking everything.
  Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(),
      Builder.getPtrTy(SI->getPointerAddressSpace()), ExpPt);
  Value *LoadBasePtr = nullptr;

  bool Overlap = false;
  bool DestVolatile = SI->isVolatile();
  Type *BECountTy = BECount->getType();

  if (DestVolatile) {
    // The trip count must fit in i32, since it is the type of the "num_words"
    // argument to hexagon_memcpy_forward_vp4cp4n2.
    if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) {
    // Shared failure path: also the target of the gotos below.
    CleanupAndExit:
      // If we generated new code for the base pointer, clean up.
      Expander.clear();
      if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) {
        RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
        StoreBasePtr = nullptr;
      }
      if (LoadBasePtr) {
        RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
        LoadBasePtr = nullptr;
      }
      return false;
    }
  }

  SmallPtrSet<Instruction*, 2> Ignore1;
  Ignore1.insert(SI);
  if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
                            StoreSize, *AA, Ignore1)) {
    // Check if the load is the offending instruction.
    Ignore1.insert(LI);
    if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
                              BECount, StoreSize, *AA, Ignore1)) {
      // Still bad. Nothing we can do.
      goto CleanupAndExit;
    }
    // It worked with the load ignored.
    // The load overlaps the store range, so memcpy is out; memmove may
    // still be possible.
    Overlap = true;
  }

  if (!Overlap) {
    if (DisableMemcpyIdiom || !HasMemcpy)
      goto CleanupAndExit;
  } else {
    // Don't generate memmove if this function will be inlined. This is
    // because the caller will undergo this transformation after inlining.
    Function *Func = CurLoop->getHeader()->getParent();
    if (Func->hasFnAttribute(Attribute::AlwaysInline))
      goto CleanupAndExit;

    // In case of a memmove, the call to memmove will be executed instead
    // of the loop, so we need to make sure that there is nothing else in
    // the loop than the load, store and instructions that these two depend
    // on.
    SmallVector<Instruction*,2> Insts;
    Insts.push_back(SI);
    Insts.push_back(LI);
    if (!coverLoop(CurLoop, Insts))
      goto CleanupAndExit;

    if (DisableMemmoveIdiom || !HasMemmove)
      goto CleanupAndExit;
    bool IsNested = CurLoop->getParentLoop() != nullptr;
    if (IsNested && OnlyNonNestedMemmove)
      goto CleanupAndExit;
  }

  // For a memcpy, we have to make sure that the input array is not being
  // mutated by the loop.
2121 LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(), 2122 Builder.getPtrTy(LI->getPointerAddressSpace()), ExpPt); 2123 2124 SmallPtrSet<Instruction*, 2> Ignore2; 2125 Ignore2.insert(SI); 2126 if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, 2127 StoreSize, *AA, Ignore2)) 2128 goto CleanupAndExit; 2129 2130 // Check the stride. 2131 bool StridePos = getSCEVStride(LoadEv) >= 0; 2132 2133 // Currently, the volatile memcpy only emulates traversing memory forward. 2134 if (!StridePos && DestVolatile) 2135 goto CleanupAndExit; 2136 2137 bool RuntimeCheck = (Overlap || DestVolatile); 2138 2139 BasicBlock *ExitB; 2140 if (RuntimeCheck) { 2141 // The runtime check needs a single exit block. 2142 SmallVector<BasicBlock*, 8> ExitBlocks; 2143 CurLoop->getUniqueExitBlocks(ExitBlocks); 2144 if (ExitBlocks.size() != 1) 2145 goto CleanupAndExit; 2146 ExitB = ExitBlocks[0]; 2147 } 2148 2149 // The # stored bytes is (BECount+1)*Size. Expand the trip count out to 2150 // pointer size if it isn't already. 
2151 LLVMContext &Ctx = SI->getContext(); 2152 BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); 2153 DebugLoc DLoc = SI->getDebugLoc(); 2154 2155 const SCEV *NumBytesS = 2156 SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); 2157 if (StoreSize != 1) 2158 NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), 2159 SCEV::FlagNUW); 2160 Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt); 2161 if (Instruction *In = dyn_cast<Instruction>(NumBytes)) 2162 if (Value *Simp = simplifyInstruction(In, {*DL, TLI, DT})) 2163 NumBytes = Simp; 2164 2165 CallInst *NewCall; 2166 2167 if (RuntimeCheck) { 2168 unsigned Threshold = RuntimeMemSizeThreshold; 2169 if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) { 2170 uint64_t C = CI->getZExtValue(); 2171 if (Threshold != 0 && C < Threshold) 2172 goto CleanupAndExit; 2173 if (C < CompileTimeMemSizeThreshold) 2174 goto CleanupAndExit; 2175 } 2176 2177 BasicBlock *Header = CurLoop->getHeader(); 2178 Function *Func = Header->getParent(); 2179 Loop *ParentL = LF->getLoopFor(Preheader); 2180 StringRef HeaderName = Header->getName(); 2181 2182 // Create a new (empty) preheader, and update the PHI nodes in the 2183 // header to use the new preheader. 2184 BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph", 2185 Func, Header); 2186 if (ParentL) 2187 ParentL->addBasicBlockToLoop(NewPreheader, *LF); 2188 IRBuilder<>(NewPreheader).CreateBr(Header); 2189 for (auto &In : *Header) { 2190 PHINode *PN = dyn_cast<PHINode>(&In); 2191 if (!PN) 2192 break; 2193 int bx = PN->getBasicBlockIndex(Preheader); 2194 if (bx >= 0) 2195 PN->setIncomingBlock(bx, NewPreheader); 2196 } 2197 DT->addNewBlock(NewPreheader, Preheader); 2198 DT->changeImmediateDominator(Header, NewPreheader); 2199 2200 // Check for safe conditions to execute memmove. 2201 // If stride is positive, copying things from higher to lower addresses 2202 // is equivalent to memmove. 
For negative stride, it's the other way 2203 // around. Copying forward in memory with positive stride may not be 2204 // same as memmove since we may be copying values that we just stored 2205 // in some previous iteration. 2206 Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy); 2207 Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy); 2208 Value *LowA = StridePos ? SA : LA; 2209 Value *HighA = StridePos ? LA : SA; 2210 Value *CmpA = Builder.CreateICmpULT(LowA, HighA); 2211 Value *Cond = CmpA; 2212 2213 // Check for distance between pointers. Since the case LowA < HighA 2214 // is checked for above, assume LowA >= HighA. 2215 Value *Dist = Builder.CreateSub(LowA, HighA); 2216 Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist); 2217 Value *CmpEither = Builder.CreateOr(Cond, CmpD); 2218 Cond = CmpEither; 2219 2220 if (Threshold != 0) { 2221 Type *Ty = NumBytes->getType(); 2222 Value *Thr = ConstantInt::get(Ty, Threshold); 2223 Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes); 2224 Value *CmpBoth = Builder.CreateAnd(Cond, CmpB); 2225 Cond = CmpBoth; 2226 } 2227 BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli", 2228 Func, NewPreheader); 2229 if (ParentL) 2230 ParentL->addBasicBlockToLoop(MemmoveB, *LF); 2231 Instruction *OldT = Preheader->getTerminator(); 2232 Builder.CreateCondBr(Cond, MemmoveB, NewPreheader); 2233 OldT->eraseFromParent(); 2234 Preheader->setName(Preheader->getName()+".old"); 2235 DT->addNewBlock(MemmoveB, Preheader); 2236 // Find the new immediate dominator of the exit block. 2237 BasicBlock *ExitD = Preheader; 2238 for (BasicBlock *PB : predecessors(ExitB)) { 2239 ExitD = DT->findNearestCommonDominator(ExitD, PB); 2240 if (!ExitD) 2241 break; 2242 } 2243 // If the prior immediate dominator of ExitB was dominated by the 2244 // old preheader, then the old preheader becomes the new immediate 2245 // dominator. 
Otherwise don't change anything (because the newly 2246 // added blocks are dominated by the old preheader). 2247 if (ExitD && DT->dominates(Preheader, ExitD)) { 2248 DomTreeNode *BN = DT->getNode(ExitB); 2249 DomTreeNode *DN = DT->getNode(ExitD); 2250 BN->setIDom(DN); 2251 } 2252 2253 // Add a call to memmove to the conditional block. 2254 IRBuilder<> CondBuilder(MemmoveB); 2255 CondBuilder.CreateBr(ExitB); 2256 CondBuilder.SetInsertPoint(MemmoveB->getTerminator()); 2257 2258 if (DestVolatile) { 2259 Type *Int32Ty = Type::getInt32Ty(Ctx); 2260 Type *PtrTy = PointerType::get(Ctx, 0); 2261 Type *VoidTy = Type::getVoidTy(Ctx); 2262 Module *M = Func->getParent(); 2263 FunctionCallee Fn = M->getOrInsertFunction( 2264 HexagonVolatileMemcpyName, VoidTy, PtrTy, PtrTy, Int32Ty); 2265 2266 const SCEV *OneS = SE->getConstant(Int32Ty, 1); 2267 const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty); 2268 const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW); 2269 Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty, 2270 MemmoveB->getTerminator()); 2271 if (Instruction *In = dyn_cast<Instruction>(NumWords)) 2272 if (Value *Simp = simplifyInstruction(In, {*DL, TLI, DT})) 2273 NumWords = Simp; 2274 2275 NewCall = CondBuilder.CreateCall(Fn, 2276 {StoreBasePtr, LoadBasePtr, NumWords}); 2277 } else { 2278 NewCall = CondBuilder.CreateMemMove( 2279 StoreBasePtr, SI->getAlign(), LoadBasePtr, LI->getAlign(), NumBytes); 2280 } 2281 } else { 2282 NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlign(), LoadBasePtr, 2283 LI->getAlign(), NumBytes); 2284 // Okay, the memcpy has been formed. Zap the original store and 2285 // anything that feeds into it. 2286 RecursivelyDeleteTriviallyDeadInstructions(SI, TLI); 2287 } 2288 2289 NewCall->setDebugLoc(DLoc); 2290 2291 LLVM_DEBUG(dbgs() << " Formed " << (Overlap ? 
"memmove: " : "memcpy: ") 2292 << *NewCall << "\n" 2293 << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" 2294 << " from store ptr=" << *StoreEv << " at: " << *SI 2295 << "\n"); 2296 2297 return true; 2298 } 2299 2300 // Check if the instructions in Insts, together with their dependencies 2301 // cover the loop in the sense that the loop could be safely eliminated once 2302 // the instructions in Insts are removed. 2303 bool HexagonLoopIdiomRecognize::coverLoop(Loop *L, 2304 SmallVectorImpl<Instruction*> &Insts) const { 2305 SmallSet<BasicBlock*,8> LoopBlocks; 2306 for (auto *B : L->blocks()) 2307 LoopBlocks.insert(B); 2308 2309 SetVector<Instruction*> Worklist(Insts.begin(), Insts.end()); 2310 2311 // Collect all instructions from the loop that the instructions in Insts 2312 // depend on (plus their dependencies, etc.). These instructions will 2313 // constitute the expression trees that feed those in Insts, but the trees 2314 // will be limited only to instructions contained in the loop. 2315 for (unsigned i = 0; i < Worklist.size(); ++i) { 2316 Instruction *In = Worklist[i]; 2317 for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) { 2318 Instruction *OpI = dyn_cast<Instruction>(I); 2319 if (!OpI) 2320 continue; 2321 BasicBlock *PB = OpI->getParent(); 2322 if (!LoopBlocks.count(PB)) 2323 continue; 2324 Worklist.insert(OpI); 2325 } 2326 } 2327 2328 // Scan all instructions in the loop, if any of them have a user outside 2329 // of the loop, or outside of the expressions collected above, then either 2330 // the loop has a side-effect visible outside of it, or there are 2331 // instructions in it that are not involved in the original set Insts. 
2332 for (auto *B : L->blocks()) { 2333 for (auto &In : *B) { 2334 if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In)) 2335 continue; 2336 if (!Worklist.count(&In) && In.mayHaveSideEffects()) 2337 return false; 2338 for (auto *K : In.users()) { 2339 Instruction *UseI = dyn_cast<Instruction>(K); 2340 if (!UseI) 2341 continue; 2342 BasicBlock *UseB = UseI->getParent(); 2343 if (LF->getLoopFor(UseB) != L) 2344 return false; 2345 } 2346 } 2347 } 2348 2349 return true; 2350 } 2351 2352 /// runOnLoopBlock - Process the specified block, which lives in a counted loop 2353 /// with the specified backedge count. This block is known to be in the current 2354 /// loop and not in any subloops. 2355 bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, 2356 const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) { 2357 // We can only promote stores in this block if they are unconditionally 2358 // executed in the loop. For a block to be unconditionally executed, it has 2359 // to dominate all the exit blocks of the loop. Verify this now. 2360 auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool { 2361 return DT->dominates(BB, EB); 2362 }; 2363 if (!all_of(ExitBlocks, DominatedByBB)) 2364 return false; 2365 2366 bool MadeChange = false; 2367 // Look for store instructions, which may be optimized to memset/memcpy. 2368 SmallVector<StoreInst*,8> Stores; 2369 collectStores(CurLoop, BB, Stores); 2370 2371 // Optimize the store into a memcpy, if it feeds an similarly strided load. 
2372 for (auto &SI : Stores) 2373 MadeChange |= processCopyingStore(CurLoop, SI, BECount); 2374 2375 return MadeChange; 2376 } 2377 2378 bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) { 2379 PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE); 2380 if (PMR.recognize()) 2381 return true; 2382 2383 if (!HasMemcpy && !HasMemmove) 2384 return false; 2385 2386 const SCEV *BECount = SE->getBackedgeTakenCount(L); 2387 assert(!isa<SCEVCouldNotCompute>(BECount) && 2388 "runOnCountableLoop() called on a loop without a predictable" 2389 "backedge-taken count"); 2390 2391 SmallVector<BasicBlock *, 8> ExitBlocks; 2392 L->getUniqueExitBlocks(ExitBlocks); 2393 2394 bool Changed = false; 2395 2396 // Scan all the blocks in the loop that are not in subloops. 2397 for (auto *BB : L->getBlocks()) { 2398 // Ignore blocks in subloops. 2399 if (LF->getLoopFor(BB) != L) 2400 continue; 2401 Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks); 2402 } 2403 2404 return Changed; 2405 } 2406 2407 bool HexagonLoopIdiomRecognize::run(Loop *L) { 2408 const Module &M = *L->getHeader()->getParent()->getParent(); 2409 if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon) 2410 return false; 2411 2412 // If the loop could not be converted to canonical form, it must have an 2413 // indirectbr in it, just give up. 2414 if (!L->getLoopPreheader()) 2415 return false; 2416 2417 // Disable loop idiom recognition if the function's name is a common idiom. 
2418 StringRef Name = L->getHeader()->getParent()->getName(); 2419 if (Name == "memset" || Name == "memcpy" || Name == "memmove") 2420 return false; 2421 2422 DL = &L->getHeader()->getDataLayout(); 2423 2424 HasMemcpy = TLI->has(LibFunc_memcpy); 2425 HasMemmove = TLI->has(LibFunc_memmove); 2426 2427 if (SE->hasLoopInvariantBackedgeTakenCount(L)) 2428 return runOnCountableLoop(L); 2429 return false; 2430 } 2431 2432 bool HexagonLoopIdiomRecognizeLegacyPass::runOnLoop(Loop *L, 2433 LPPassManager &LPM) { 2434 if (skipLoop(L)) 2435 return false; 2436 2437 auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 2438 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 2439 auto *LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2440 auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( 2441 *L->getHeader()->getParent()); 2442 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 2443 return HexagonLoopIdiomRecognize(AA, DT, LF, TLI, SE).run(L); 2444 } 2445 2446 Pass *llvm::createHexagonLoopIdiomPass() { 2447 return new HexagonLoopIdiomRecognizeLegacyPass(); 2448 } 2449 2450 PreservedAnalyses 2451 HexagonLoopIdiomRecognitionPass::run(Loop &L, LoopAnalysisManager &AM, 2452 LoopStandardAnalysisResults &AR, 2453 LPMUpdater &U) { 2454 return HexagonLoopIdiomRecognize(&AR.AA, &AR.DT, &AR.LI, &AR.TLI, &AR.SE) 2455 .run(&L) 2456 ? getLoopPassPreservedAnalyses() 2457 : PreservedAnalyses::all(); 2458 } 2459