1 //===- HexagonLoopIdiomRecognition.cpp ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "HexagonLoopIdiomRecognition.h" 10 #include "llvm/ADT/APInt.h" 11 #include "llvm/ADT/DenseMap.h" 12 #include "llvm/ADT/SetVector.h" 13 #include "llvm/ADT/SmallPtrSet.h" 14 #include "llvm/ADT/SmallSet.h" 15 #include "llvm/ADT/SmallVector.h" 16 #include "llvm/ADT/StringRef.h" 17 #include "llvm/ADT/Triple.h" 18 #include "llvm/Analysis/AliasAnalysis.h" 19 #include "llvm/Analysis/InstructionSimplify.h" 20 #include "llvm/Analysis/LoopAnalysisManager.h" 21 #include "llvm/Analysis/LoopInfo.h" 22 #include "llvm/Analysis/LoopPass.h" 23 #include "llvm/Analysis/MemoryLocation.h" 24 #include "llvm/Analysis/ScalarEvolution.h" 25 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 26 #include "llvm/Analysis/TargetLibraryInfo.h" 27 #include "llvm/Analysis/ValueTracking.h" 28 #include "llvm/IR/Attributes.h" 29 #include "llvm/IR/BasicBlock.h" 30 #include "llvm/IR/Constant.h" 31 #include "llvm/IR/Constants.h" 32 #include "llvm/IR/DataLayout.h" 33 #include "llvm/IR/DebugLoc.h" 34 #include "llvm/IR/DerivedTypes.h" 35 #include "llvm/IR/Dominators.h" 36 #include "llvm/IR/Function.h" 37 #include "llvm/IR/IRBuilder.h" 38 #include "llvm/IR/InstrTypes.h" 39 #include "llvm/IR/Instruction.h" 40 #include "llvm/IR/Instructions.h" 41 #include "llvm/IR/IntrinsicInst.h" 42 #include "llvm/IR/Intrinsics.h" 43 #include "llvm/IR/IntrinsicsHexagon.h" 44 #include "llvm/IR/Module.h" 45 #include "llvm/IR/PassManager.h" 46 #include "llvm/IR/PatternMatch.h" 47 #include "llvm/IR/Type.h" 48 #include "llvm/IR/User.h" 49 #include "llvm/IR/Value.h" 50 #include "llvm/InitializePasses.h" 51 #include "llvm/Pass.h" 52 #include "llvm/Support/Casting.h" 53 #include "llvm/Support/CommandLine.h" 54 #include "llvm/Support/Compiler.h" 55 #include "llvm/Support/Debug.h" 56 #include "llvm/Support/ErrorHandling.h" 57 #include "llvm/Support/KnownBits.h" 58 #include "llvm/Support/raw_ostream.h" 59 #include "llvm/Transforms/Scalar.h" 60 #include "llvm/Transforms/Utils.h" 61 #include "llvm/Transforms/Utils/Local.h" 62 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" 63 #include <algorithm> 64 #include <array> 65 #include <cassert> 66 #include <cstdint> 67 #include <cstdlib> 68 #include <deque> 69 #include <functional> 70 #include <iterator> 71 #include <map> 72 #include <set> 73 #include <utility> 74 #include <vector> 75 76 #define DEBUG_TYPE "hexagon-lir" 77 78 using namespace llvm; 79 80 static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom", 81 cl::Hidden, cl::init(false), 82 cl::desc("Disable generation of memcpy in loop idiom recognition")); 83 84 static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom", 85 cl::Hidden, cl::init(false), 86 cl::desc("Disable generation of memmove in loop idiom recognition")); 87 88 static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold", 89 cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime " 90 "check guarding the memmove.")); 91 92 static cl::opt<unsigned> CompileTimeMemSizeThreshold( 93 "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64), 94 cl::desc("Threshold (in bytes) to perform the transformation, if the " 95 "runtime loop count (mem transfer size) is known at compile-time.")); 96 97 static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom", 98 cl::Hidden, cl::init(true), 99 cl::desc("Only enable generating memmove in non-nested loops")); 100 101 static cl::opt<bool> HexagonVolatileMemcpy( 102 "disable-hexagon-volatile-memcpy", cl::Hidden, cl::init(false), 103 cl::desc("Enable Hexagon-specific memcpy for volatile destination.")); 104 105 static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000), 106 cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR")); 107 108 static const char *HexagonVolatileMemcpyName 109 = "hexagon_memcpy_forward_vp4cp4n2"; 110 111 112 namespace llvm { 113 114 void initializeHexagonLoopIdiomRecognizeLegacyPassPass(PassRegistry &); 115 Pass *createHexagonLoopIdiomPass(); 116 117 } // end namespace llvm 118 119 namespace { 120 121 class HexagonLoopIdiomRecognize { 122 public: 123 explicit HexagonLoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT, 124 LoopInfo *LF, const TargetLibraryInfo *TLI, 125 ScalarEvolution *SE) 126 : AA(AA), DT(DT), LF(LF), TLI(TLI), SE(SE) {} 127 128 bool run(Loop *L); 129 130 private: 131 int getSCEVStride(const SCEVAddRecExpr *StoreEv); 132 bool isLegalStore(Loop *CurLoop, StoreInst *SI); 133 void collectStores(Loop *CurLoop, BasicBlock *BB, 134 SmallVectorImpl<StoreInst *> &Stores); 135 bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount); 136 bool coverLoop(Loop *L, SmallVectorImpl<Instruction *> &Insts) const; 137 bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount, 138 SmallVectorImpl<BasicBlock *> &ExitBlocks); 139 bool runOnCountableLoop(Loop *L); 140 141 AliasAnalysis *AA; 142 const DataLayout *DL; 143 DominatorTree *DT; 144 LoopInfo *LF; 145 const TargetLibraryInfo *TLI; 146 ScalarEvolution *SE; 147 bool HasMemcpy, HasMemmove; 148 }; 149 150 class HexagonLoopIdiomRecognizeLegacyPass : public LoopPass { 151 public: 152 static char ID; 153 154 explicit HexagonLoopIdiomRecognizeLegacyPass() : LoopPass(ID) { 155 initializeHexagonLoopIdiomRecognizeLegacyPassPass( 156 *PassRegistry::getPassRegistry()); 157 } 158 159 StringRef getPassName() const override { 160 return "Recognize Hexagon-specific loop idioms"; 161 } 162 163 void getAnalysisUsage(AnalysisUsage &AU) const override { 164 AU.addRequired<LoopInfoWrapperPass>(); 165 AU.addRequiredID(LoopSimplifyID); 166 AU.addRequiredID(LCSSAID); 167 AU.addRequired<AAResultsWrapperPass>(); 168 AU.addRequired<ScalarEvolutionWrapperPass>(); 169 AU.addRequired<DominatorTreeWrapperPass>(); 170 AU.addRequired<TargetLibraryInfoWrapperPass>(); 171 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 172 } 173 174 bool runOnLoop(Loop *L, LPPassManager &LPM) override; 175 }; 176 177 struct Simplifier { 178 struct Rule { 179 using FuncType = std::function<Value *(Instruction *, LLVMContext &)>; 180 Rule(StringRef N, FuncType F) : Name(N), Fn(F) {} 181 StringRef Name; // For debugging. 182 FuncType Fn; 183 }; 184 185 void addRule(StringRef N, const Rule::FuncType &F) { 186 Rules.push_back(Rule(N, F)); 187 } 188 189 private: 190 struct WorkListType { 191 WorkListType() = default; 192 193 void push_back(Value *V) { 194 // Do not push back duplicates. 195 if (!S.count(V)) { 196 Q.push_back(V); 197 S.insert(V); 198 } 199 } 200 201 Value *pop_front_val() { 202 Value *V = Q.front(); 203 Q.pop_front(); 204 S.erase(V); 205 return V; 206 } 207 208 bool empty() const { return Q.empty(); } 209 210 private: 211 std::deque<Value *> Q; 212 std::set<Value *> S; 213 }; 214 215 using ValueSetType = std::set<Value *>; 216 217 std::vector<Rule> Rules; 218 219 public: 220 struct Context { 221 using ValueMapType = DenseMap<Value *, Value *>; 222 223 Value *Root; 224 ValueSetType Used; // The set of all cloned values used by Root. 225 ValueSetType Clones; // The set of all cloned values. 226 LLVMContext &Ctx; 227 228 Context(Instruction *Exp) 229 : Ctx(Exp->getParent()->getParent()->getContext()) { 230 initialize(Exp); 231 } 232 233 ~Context() { cleanup(); } 234 235 void print(raw_ostream &OS, const Value *V) const; 236 Value *materialize(BasicBlock *B, BasicBlock::iterator At); 237 238 private: 239 friend struct Simplifier; 240 241 void initialize(Instruction *Exp); 242 void cleanup(); 243 244 template <typename FuncT> void traverse(Value *V, FuncT F); 245 void record(Value *V); 246 void use(Value *V); 247 void unuse(Value *V); 248 249 bool equal(const Instruction *I, const Instruction *J) const; 250 Value *find(Value *Tree, Value *Sub) const; 251 Value *subst(Value *Tree, Value *OldV, Value *NewV); 252 void replace(Value *OldV, Value *NewV); 253 void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At); 254 }; 255 256 Value *simplify(Context &C); 257 }; 258 259 struct PE { 260 PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {} 261 262 const Simplifier::Context &C; 263 const Value *V; 264 }; 265 266 LLVM_ATTRIBUTE_USED 267 raw_ostream &operator<<(raw_ostream &OS, const PE &P) { 268 P.C.print(OS, P.V ? P.V : P.C.Root); 269 return OS; 270 } 271 272 } // end anonymous namespace 273 274 char HexagonLoopIdiomRecognizeLegacyPass::ID = 0; 275 276 INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom", 277 "Recognize Hexagon-specific loop idioms", false, false) 278 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 279 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 280 INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) 281 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 282 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 283 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 284 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 285 INITIALIZE_PASS_END(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom", 286 "Recognize Hexagon-specific loop idioms", false, false) 287 288 template <typename FuncT> 289 void Simplifier::Context::traverse(Value *V, FuncT F) { 290 WorkListType Q; 291 Q.push_back(V); 292 293 while (!Q.empty()) { 294 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 295 if (!U || U->getParent()) 296 continue; 297 if (!F(U)) 298 continue; 299 for (Value *Op : U->operands()) 300 Q.push_back(Op); 301 } 302 } 303 304 void Simplifier::Context::print(raw_ostream &OS, const Value *V) const { 305 const auto *U = dyn_cast<const Instruction>(V); 306 if (!U) { 307 OS << V << '(' << *V << ')'; 308 return; 309 } 310 311 if (U->getParent()) { 312 OS << U << '('; 313 U->printAsOperand(OS, true); 314 OS << ')'; 315 return; 316 } 317 318 unsigned N = U->getNumOperands(); 319 if (N != 0) 320 OS << U << '('; 321 OS << U->getOpcodeName(); 322 for (const Value *Op : U->operands()) { 323 OS << ' '; 324 print(OS, Op); 325 } 326 if (N != 0) 327 OS << ')'; 328 } 329 330 void Simplifier::Context::initialize(Instruction *Exp) { 331 // Perform a deep clone of the expression, set Root to the root 332 // of the clone, and build a map from the cloned values to the 333 // original ones. 334 ValueMapType M; 335 BasicBlock *Block = Exp->getParent(); 336 WorkListType Q; 337 Q.push_back(Exp); 338 339 while (!Q.empty()) { 340 Value *V = Q.pop_front_val(); 341 if (M.find(V) != M.end()) 342 continue; 343 if (Instruction *U = dyn_cast<Instruction>(V)) { 344 if (isa<PHINode>(U) || U->getParent() != Block) 345 continue; 346 for (Value *Op : U->operands()) 347 Q.push_back(Op); 348 M.insert({U, U->clone()}); 349 } 350 } 351 352 for (std::pair<Value*,Value*> P : M) { 353 Instruction *U = cast<Instruction>(P.second); 354 for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) { 355 auto F = M.find(U->getOperand(i)); 356 if (F != M.end()) 357 U->setOperand(i, F->second); 358 } 359 } 360 361 auto R = M.find(Exp); 362 assert(R != M.end()); 363 Root = R->second; 364 365 record(Root); 366 use(Root); 367 } 368 369 void Simplifier::Context::record(Value *V) { 370 auto Record = [this](Instruction *U) -> bool { 371 Clones.insert(U); 372 return true; 373 }; 374 traverse(V, Record); 375 } 376 377 void Simplifier::Context::use(Value *V) { 378 auto Use = [this](Instruction *U) -> bool { 379 Used.insert(U); 380 return true; 381 }; 382 traverse(V, Use); 383 } 384 385 void Simplifier::Context::unuse(Value *V) { 386 if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr) 387 return; 388 389 auto Unuse = [this](Instruction *U) -> bool { 390 if (!U->use_empty()) 391 return false; 392 Used.erase(U); 393 return true; 394 }; 395 traverse(V, Unuse); 396 } 397 398 Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) { 399 if (Tree == OldV) 400 return NewV; 401 if (OldV == NewV) 402 return Tree; 403 404 WorkListType Q; 405 Q.push_back(Tree); 406 while (!Q.empty()) { 407 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 408 // If U is not an instruction, or it's not a clone, skip it. 409 if (!U || U->getParent()) 410 continue; 411 for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) { 412 Value *Op = U->getOperand(i); 413 if (Op == OldV) { 414 U->setOperand(i, NewV); 415 unuse(OldV); 416 } else { 417 Q.push_back(Op); 418 } 419 } 420 } 421 return Tree; 422 } 423 424 void Simplifier::Context::replace(Value *OldV, Value *NewV) { 425 if (Root == OldV) { 426 Root = NewV; 427 use(Root); 428 return; 429 } 430 431 // NewV may be a complex tree that has just been created by one of the 432 // transformation rules. We need to make sure that it is commoned with 433 // the existing Root to the maximum extent possible. 434 // Identify all subtrees of NewV (including NewV itself) that have 435 // equivalent counterparts in Root, and replace those subtrees with 436 // these counterparts. 437 WorkListType Q; 438 Q.push_back(NewV); 439 while (!Q.empty()) { 440 Value *V = Q.pop_front_val(); 441 Instruction *U = dyn_cast<Instruction>(V); 442 if (!U || U->getParent()) 443 continue; 444 if (Value *DupV = find(Root, V)) { 445 if (DupV != V) 446 NewV = subst(NewV, V, DupV); 447 } else { 448 for (Value *Op : U->operands()) 449 Q.push_back(Op); 450 } 451 } 452 453 // Now, simply replace OldV with NewV in Root. 454 Root = subst(Root, OldV, NewV); 455 use(Root); 456 } 457 458 void Simplifier::Context::cleanup() { 459 for (Value *V : Clones) { 460 Instruction *U = cast<Instruction>(V); 461 if (!U->getParent()) 462 U->dropAllReferences(); 463 } 464 465 for (Value *V : Clones) { 466 Instruction *U = cast<Instruction>(V); 467 if (!U->getParent()) 468 U->deleteValue(); 469 } 470 } 471 472 bool Simplifier::Context::equal(const Instruction *I, 473 const Instruction *J) const { 474 if (I == J) 475 return true; 476 if (!I->isSameOperationAs(J)) 477 return false; 478 if (isa<PHINode>(I)) 479 return I->isIdenticalTo(J); 480 481 for (unsigned i = 0, n = I->getNumOperands(); i != n; ++i) { 482 Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i); 483 if (OpI == OpJ) 484 continue; 485 auto *InI = dyn_cast<const Instruction>(OpI); 486 auto *InJ = dyn_cast<const Instruction>(OpJ); 487 if (InI && InJ) { 488 if (!equal(InI, InJ)) 489 return false; 490 } else if (InI != InJ || !InI) 491 return false; 492 } 493 return true; 494 } 495 496 Value *Simplifier::Context::find(Value *Tree, Value *Sub) const { 497 Instruction *SubI = dyn_cast<Instruction>(Sub); 498 WorkListType Q; 499 Q.push_back(Tree); 500 501 while (!Q.empty()) { 502 Value *V = Q.pop_front_val(); 503 if (V == Sub) 504 return V; 505 Instruction *U = dyn_cast<Instruction>(V); 506 if (!U || U->getParent()) 507 continue; 508 if (SubI && equal(SubI, U)) 509 return U; 510 assert(!isa<PHINode>(U)); 511 for (Value *Op : U->operands()) 512 Q.push_back(Op); 513 } 514 return nullptr; 515 } 516 517 void Simplifier::Context::link(Instruction *I, BasicBlock *B, 518 BasicBlock::iterator At) { 519 if (I->getParent()) 520 return; 521 522 for (Value *Op : I->operands()) { 523 if (Instruction *OpI = dyn_cast<Instruction>(Op)) 524 link(OpI, B, At); 525 } 526 527 B->getInstList().insert(At, I); 528 } 529 530 Value *Simplifier::Context::materialize(BasicBlock *B, 531 BasicBlock::iterator At) { 532 if (Instruction *RootI = dyn_cast<Instruction>(Root)) 533 link(RootI, B, At); 534 return Root; 535 } 536 537 Value *Simplifier::simplify(Context &C) { 538 WorkListType Q; 539 Q.push_back(C.Root); 540 unsigned Count = 0; 541 const unsigned Limit = SimplifyLimit; 542 543 while (!Q.empty()) { 544 if (Count++ >= Limit) 545 break; 546 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 547 if (!U || U->getParent() || !C.Used.count(U)) 548 continue; 549 bool Changed = false; 550 for (Rule &R : Rules) { 551 Value *W = R.Fn(U, C.Ctx); 552 if (!W) 553 continue; 554 Changed = true; 555 C.record(W); 556 C.replace(U, W); 557 Q.push_back(C.Root); 558 break; 559 } 560 if (!Changed) { 561 for (Value *Op : U->operands()) 562 Q.push_back(Op); 563 } 564 } 565 return Count < Limit ? C.Root : nullptr; 566 } 567 568 //===----------------------------------------------------------------------===// 569 // 570 // Implementation of PolynomialMultiplyRecognize 571 // 572 //===----------------------------------------------------------------------===// 573 574 namespace { 575 576 class PolynomialMultiplyRecognize { 577 public: 578 explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl, 579 const DominatorTree &dt, const TargetLibraryInfo &tli, 580 ScalarEvolution &se) 581 : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {} 582 583 bool recognize(); 584 585 private: 586 using ValueSeq = SetVector<Value *>; 587 588 IntegerType *getPmpyType() const { 589 LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext(); 590 return IntegerType::get(Ctx, 32); 591 } 592 593 bool isPromotableTo(Value *V, IntegerType *Ty); 594 void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB); 595 bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB); 596 597 Value *getCountIV(BasicBlock *BB); 598 bool findCycle(Value *Out, Value *In, ValueSeq &Cycle); 599 void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early, 600 ValueSeq &Late); 601 bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late); 602 bool commutesWithShift(Instruction *I); 603 bool highBitsAreZero(Value *V, unsigned IterCount); 604 bool keepsHighBitsZero(Value *V, unsigned IterCount); 605 bool isOperandShifted(Instruction *I, Value *Op); 606 bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB, 607 unsigned IterCount); 608 void cleanupLoopBody(BasicBlock *LoopB); 609 610 struct ParsedValues { 611 ParsedValues() = default; 612 613 Value *M = nullptr; 614 Value *P = nullptr; 615 Value *Q = nullptr; 616 Value *R = nullptr; 617 Value *X = nullptr; 618 Instruction *Res = nullptr; 619 unsigned IterCount = 0; 620 bool Left = false; 621 bool Inv = false; 622 }; 623 624 bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV); 625 bool matchRightShift(SelectInst *SelI, ParsedValues &PV); 626 bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB, 627 Value *CIV, ParsedValues &PV, bool PreScan); 628 unsigned getInverseMxN(unsigned QP); 629 Value *generate(BasicBlock::iterator At, ParsedValues &PV); 630 631 void setupPreSimplifier(Simplifier &S); 632 void setupPostSimplifier(Simplifier &S); 633 634 Loop *CurLoop; 635 const DataLayout &DL; 636 const DominatorTree &DT; 637 const TargetLibraryInfo &TLI; 638 ScalarEvolution &SE; 639 }; 640 641 } // end anonymous namespace 642 643 Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) { 644 pred_iterator PI = pred_begin(BB), PE = pred_end(BB); 645 if (std::distance(PI, PE) != 2) 646 return nullptr; 647 BasicBlock *PB = (*PI == BB) ? *std::next(PI) : *PI; 648 649 for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) { 650 auto *PN = cast<PHINode>(I); 651 Value *InitV = PN->getIncomingValueForBlock(PB); 652 if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero()) 653 continue; 654 Value *IterV = PN->getIncomingValueForBlock(BB); 655 auto *BO = dyn_cast<BinaryOperator>(IterV); 656 if (!BO) 657 continue; 658 if (BO->getOpcode() != Instruction::Add) 659 continue; 660 Value *IncV = nullptr; 661 if (BO->getOperand(0) == PN) 662 IncV = BO->getOperand(1); 663 else if (BO->getOperand(1) == PN) 664 IncV = BO->getOperand(0); 665 if (IncV == nullptr) 666 continue; 667 668 if (auto *T = dyn_cast<ConstantInt>(IncV)) 669 if (T->getZExtValue() == 1) 670 return PN; 671 } 672 return nullptr; 673 } 674 675 static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) { 676 for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) { 677 Use &TheUse = UI.getUse(); 678 ++UI; 679 if (auto *II = dyn_cast<Instruction>(TheUse.getUser())) 680 if (BB == II->getParent()) 681 II->replaceUsesOfWith(I, J); 682 } 683 } 684 685 bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI, 686 Value *CIV, ParsedValues &PV) { 687 // Match the following: 688 // select (X & (1 << i)) != 0 ? R ^ (Q << i) : R 689 // select (X & (1 << i)) == 0 ? R : R ^ (Q << i) 690 // The condition may also check for equality with the masked value, i.e 691 // select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R 692 // select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i); 693 694 Value *CondV = SelI->getCondition(); 695 Value *TrueV = SelI->getTrueValue(); 696 Value *FalseV = SelI->getFalseValue(); 697 698 using namespace PatternMatch; 699 700 CmpInst::Predicate P; 701 Value *A = nullptr, *B = nullptr, *C = nullptr; 702 703 if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) && 704 !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B))))) 705 return false; 706 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 707 return false; 708 // Matched: select (A & B) == C ? ... : ... 709 // select (A & B) != C ? ... : ... 710 711 Value *X = nullptr, *Sh1 = nullptr; 712 // Check (A & B) for (X & (1 << i)): 713 if (match(A, m_Shl(m_One(), m_Specific(CIV)))) { 714 Sh1 = A; 715 X = B; 716 } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) { 717 Sh1 = B; 718 X = A; 719 } else { 720 // TODO: Could also check for an induction variable containing single 721 // bit shifted left by 1 in each iteration. 722 return false; 723 } 724 725 bool TrueIfZero; 726 727 // Check C against the possible values for comparison: 0 and (1 << i): 728 if (match(C, m_Zero())) 729 TrueIfZero = (P == CmpInst::ICMP_EQ); 730 else if (C == Sh1) 731 TrueIfZero = (P == CmpInst::ICMP_NE); 732 else 733 return false; 734 735 // So far, matched: 736 // select (X & (1 << i)) ? ... : ... 737 // including variations of the check against zero/non-zero value. 738 739 Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr; 740 if (TrueIfZero) { 741 ShouldSameV = TrueV; 742 ShouldXoredV = FalseV; 743 } else { 744 ShouldSameV = FalseV; 745 ShouldXoredV = TrueV; 746 } 747 748 Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr; 749 Value *T = nullptr; 750 if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) { 751 // Matched: select +++ ? ... : Y ^ Z 752 // select +++ ? Y ^ Z : ... 753 // where +++ denotes previously checked matches. 754 if (ShouldSameV == Y) 755 T = Z; 756 else if (ShouldSameV == Z) 757 T = Y; 758 else 759 return false; 760 R = ShouldSameV; 761 // Matched: select +++ ? R : R ^ T 762 // select +++ ? R ^ T : R 763 // depending on TrueIfZero. 764 765 } else if (match(ShouldSameV, m_Zero())) { 766 // Matched: select +++ ? 0 : ... 767 // select +++ ? ... : 0 768 if (!SelI->hasOneUse()) 769 return false; 770 T = ShouldXoredV; 771 // Matched: select +++ ? 0 : T 772 // select +++ ? T : 0 773 774 Value *U = *SelI->user_begin(); 775 if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) && 776 !match(U, m_Xor(m_Value(R), m_Specific(SelI)))) 777 return false; 778 // Matched: xor (select +++ ? 0 : T), R 779 // xor (select +++ ? T : 0), R 780 } else 781 return false; 782 783 // The xor input value T is isolated into its own match so that it could 784 // be checked against an induction variable containing a shifted bit 785 // (todo). 786 // For now, check against (Q << i). 787 if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) && 788 !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV))))) 789 return false; 790 // Matched: select +++ ? R : R ^ (Q << i) 791 // select +++ ? R ^ (Q << i) : R 792 793 PV.X = X; 794 PV.Q = Q; 795 PV.R = R; 796 PV.Left = true; 797 return true; 798 } 799 800 bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI, 801 ParsedValues &PV) { 802 // Match the following: 803 // select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1) 804 // select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q 805 // The condition may also check for equality with the masked value, i.e 806 // select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1) 807 // select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q 808 809 Value *CondV = SelI->getCondition(); 810 Value *TrueV = SelI->getTrueValue(); 811 Value *FalseV = SelI->getFalseValue(); 812 813 using namespace PatternMatch; 814 815 Value *C = nullptr; 816 CmpInst::Predicate P; 817 bool TrueIfZero; 818 819 if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) || 820 match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))) { 821 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 822 return false; 823 // Matched: select C == 0 ? ... : ... 824 // select C != 0 ? ... : ... 825 TrueIfZero = (P == CmpInst::ICMP_EQ); 826 } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) || 827 match(CondV, m_ICmp(P, m_One(), m_Value(C)))) { 828 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 829 return false; 830 // Matched: select C == 1 ? ... : ... 831 // select C != 1 ? ... : ... 832 TrueIfZero = (P == CmpInst::ICMP_NE); 833 } else 834 return false; 835 836 Value *X = nullptr; 837 if (!match(C, m_And(m_Value(X), m_One())) && 838 !match(C, m_And(m_One(), m_Value(X)))) 839 return false; 840 // Matched: select (X & 1) == +++ ? ... : ... 841 // select (X & 1) != +++ ? ... : ... 842 843 Value *R = nullptr, *Q = nullptr; 844 if (TrueIfZero) { 845 // The select's condition is true if the tested bit is 0. 846 // TrueV must be the shift, FalseV must be the xor. 847 if (!match(TrueV, m_LShr(m_Value(R), m_One()))) 848 return false; 849 // Matched: select +++ ? (R >> 1) : ... 850 if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) && 851 !match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV)))) 852 return false; 853 // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q 854 // with commuting ^. 855 } else { 856 // The select's condition is true if the tested bit is 1. 857 // TrueV must be the xor, FalseV must be the shift. 858 if (!match(FalseV, m_LShr(m_Value(R), m_One()))) 859 return false; 860 // Matched: select +++ ? ... : (R >> 1) 861 if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) && 862 !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV)))) 863 return false; 864 // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1) 865 // with commuting ^. 866 } 867 868 PV.X = X; 869 PV.Q = Q; 870 PV.R = R; 871 PV.Left = false; 872 return true; 873 } 874 875 bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI, 876 BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV, 877 bool PreScan) { 878 using namespace PatternMatch; 879 880 // The basic pattern for R = P.Q is: 881 // for i = 0..31 882 // R = phi (0, R') 883 // if (P & (1 << i)) ; test-bit(P, i) 884 // R' = R ^ (Q << i) 885 // 886 // Similarly, the basic pattern for R = (P/Q).Q - P 887 // for i = 0..31 888 // R = phi(P, R') 889 // if (R & (1 << i)) 890 // R' = R ^ (Q << i) 891 892 // There exist idioms, where instead of Q being shifted left, P is shifted 893 // right. This produces a result that is shifted right by 32 bits (the 894 // non-shifted result is 64-bit). 895 // 896 // For R = P.Q, this would be: 897 // for i = 0..31 898 // R = phi (0, R') 899 // if ((P >> i) & 1) 900 // R' = (R >> 1) ^ Q ; R is cycled through the loop, so it must 901 // else ; be shifted by 1, not i. 902 // R' = R >> 1 903 // 904 // And for the inverse: 905 // for i = 0..31 906 // R = phi (P, R') 907 // if (R & 1) 908 // R' = (R >> 1) ^ Q 909 // else 910 // R' = R >> 1 911 912 // The left-shifting idioms share the same pattern: 913 // select (X & (1 << i)) ? R ^ (Q << i) : R 914 // Similarly for right-shifting idioms: 915 // select (X & 1) ? (R >> 1) ^ Q 916 917 if (matchLeftShift(SelI, CIV, PV)) { 918 // If this is a pre-scan, getting this far is sufficient. 919 if (PreScan) 920 return true; 921 922 // Need to make sure that the SelI goes back into R. 923 auto *RPhi = dyn_cast<PHINode>(PV.R); 924 if (!RPhi) 925 return false; 926 if (SelI != RPhi->getIncomingValueForBlock(LoopB)) 927 return false; 928 PV.Res = SelI; 929 930 // If X is loop invariant, it must be the input polynomial, and the 931 // idiom is the basic polynomial multiply. 932 if (CurLoop->isLoopInvariant(PV.X)) { 933 PV.P = PV.X; 934 PV.Inv = false; 935 } else { 936 // X is not loop invariant. If X == R, this is the inverse pmpy. 937 // Otherwise, check for an xor with an invariant value. If the 938 // variable argument to the xor is R, then this is still a valid 939 // inverse pmpy. 940 PV.Inv = true; 941 if (PV.X != PV.R) { 942 Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr; 943 if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2)))) 944 return false; 945 auto *I1 = dyn_cast<Instruction>(X1); 946 auto *I2 = dyn_cast<Instruction>(X2); 947 if (!I1 || I1->getParent() != LoopB) { 948 Var = X2; 949 Inv = X1; 950 } else if (!I2 || I2->getParent() != LoopB) { 951 Var = X1; 952 Inv = X2; 953 } else 954 return false; 955 if (Var != PV.R) 956 return false; 957 PV.M = Inv; 958 } 959 // The input polynomial P still needs to be determined. It will be 960 // the entry value of R. 961 Value *EntryP = RPhi->getIncomingValueForBlock(PrehB); 962 PV.P = EntryP; 963 } 964 965 return true; 966 } 967 968 if (matchRightShift(SelI, PV)) { 969 // If this is an inverse pattern, the Q polynomial must be known at 970 // compile time. 971 if (PV.Inv && !isa<ConstantInt>(PV.Q)) 972 return false; 973 if (PreScan) 974 return true; 975 // There is no exact matching of right-shift pmpy. 976 return false; 977 } 978 979 return false; 980 } 981 982 bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val, 983 IntegerType *DestTy) { 984 IntegerType *T = dyn_cast<IntegerType>(Val->getType()); 985 if (!T || T->getBitWidth() > DestTy->getBitWidth()) 986 return false; 987 if (T->getBitWidth() == DestTy->getBitWidth()) 988 return true; 989 // Non-instructions are promotable. The reason why an instruction may not 990 // be promotable is that it may produce a different result if its operands 991 // and the result are promoted, for example, it may produce more non-zero 992 // bits. While it would still be possible to represent the proper result 993 // in a wider type, it may require adding additional instructions (which 994 // we don't want to do). 995 Instruction *In = dyn_cast<Instruction>(Val); 996 if (!In) 997 return true; 998 // The bitwidth of the source type is smaller than the destination. 999 // Check if the individual operation can be promoted. 1000 switch (In->getOpcode()) { 1001 case Instruction::PHI: 1002 case Instruction::ZExt: 1003 case Instruction::And: 1004 case Instruction::Or: 1005 case Instruction::Xor: 1006 case Instruction::LShr: // Shift right is ok. 1007 case Instruction::Select: 1008 case Instruction::Trunc: 1009 return true; 1010 case Instruction::ICmp: 1011 if (CmpInst *CI = cast<CmpInst>(In)) 1012 return CI->isEquality() || CI->isUnsigned(); 1013 llvm_unreachable("Cast failed unexpectedly"); 1014 case Instruction::Add: 1015 return In->hasNoSignedWrap() && In->hasNoUnsignedWrap(); 1016 } 1017 return false; 1018 } 1019 1020 void PolynomialMultiplyRecognize::promoteTo(Instruction *In, 1021 IntegerType *DestTy, BasicBlock *LoopB) { 1022 Type *OrigTy = In->getType(); 1023 assert(!OrigTy->isVoidTy() && "Invalid instruction to promote"); 1024 1025 // Leave boolean values alone. 1026 if (!In->getType()->isIntegerTy(1)) 1027 In->mutateType(DestTy); 1028 unsigned DestBW = DestTy->getBitWidth(); 1029 1030 // Handle PHIs. 1031 if (PHINode *P = dyn_cast<PHINode>(In)) { 1032 unsigned N = P->getNumIncomingValues(); 1033 for (unsigned i = 0; i != N; ++i) { 1034 BasicBlock *InB = P->getIncomingBlock(i); 1035 if (InB == LoopB) 1036 continue; 1037 Value *InV = P->getIncomingValue(i); 1038 IntegerType *Ty = cast<IntegerType>(InV->getType()); 1039 // Do not promote values in PHI nodes of type i1. 1040 if (Ty != P->getType()) { 1041 // If the value type does not match the PHI type, the PHI type 1042 // must have been promoted. 1043 assert(Ty->getBitWidth() < DestBW); 1044 InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy); 1045 P->setIncomingValue(i, InV); 1046 } 1047 } 1048 } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) { 1049 Value *Op = Z->getOperand(0); 1050 if (Op->getType() == Z->getType()) 1051 Z->replaceAllUsesWith(Op); 1052 Z->eraseFromParent(); 1053 return; 1054 } 1055 if (TruncInst *T = dyn_cast<TruncInst>(In)) { 1056 IntegerType *TruncTy = cast<IntegerType>(OrigTy); 1057 Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1); 1058 Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask); 1059 T->replaceAllUsesWith(And); 1060 T->eraseFromParent(); 1061 return; 1062 } 1063 1064 // Promote immediates. 1065 for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) { 1066 if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i))) 1067 if (CI->getType()->getBitWidth() < DestBW) 1068 In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue())); 1069 } 1070 } 1071 1072 bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB, 1073 BasicBlock *ExitB) { 1074 assert(LoopB); 1075 // Skip loops where the exit block has more than one predecessor. The values 1076 // coming from the loop block will be promoted to another type, and so the 1077 // values coming into the exit block from other predecessors would also have 1078 // to be promoted. 1079 if (!ExitB || (ExitB->getSinglePredecessor() != LoopB)) 1080 return false; 1081 IntegerType *DestTy = getPmpyType(); 1082 // Check if the exit values have types that are no wider than the type 1083 // that we want to promote to. 1084 unsigned DestBW = DestTy->getBitWidth(); 1085 for (PHINode &P : ExitB->phis()) { 1086 if (P.getNumIncomingValues() != 1) 1087 return false; 1088 assert(P.getIncomingBlock(0) == LoopB); 1089 IntegerType *T = dyn_cast<IntegerType>(P.getType()); 1090 if (!T || T->getBitWidth() > DestBW) 1091 return false; 1092 } 1093 1094 // Check all instructions in the loop. 1095 for (Instruction &In : *LoopB) 1096 if (!In.isTerminator() && !isPromotableTo(&In, DestTy)) 1097 return false; 1098 1099 // Perform the promotion. 1100 std::vector<Instruction*> LoopIns; 1101 std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns), 1102 [](Instruction &In) { return &In; }); 1103 for (Instruction *In : LoopIns) 1104 if (!In->isTerminator()) 1105 promoteTo(In, DestTy, LoopB); 1106 1107 // Fix up the PHI nodes in the exit block. 1108 Instruction *EndI = ExitB->getFirstNonPHI(); 1109 BasicBlock::iterator End = EndI ? EndI->getIterator() : ExitB->end(); 1110 for (auto I = ExitB->begin(); I != End; ++I) { 1111 PHINode *P = dyn_cast<PHINode>(I); 1112 if (!P) 1113 break; 1114 Type *Ty0 = P->getIncomingValue(0)->getType(); 1115 Type *PTy = P->getType(); 1116 if (PTy != Ty0) { 1117 assert(Ty0 == DestTy); 1118 // In order to create the trunc, P must have the promoted type. 1119 P->mutateType(Ty0); 1120 Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy); 1121 // In order for the RAUW to work, the types of P and T must match. 1122 P->mutateType(PTy); 1123 P->replaceAllUsesWith(T); 1124 // Final update of the P's type. 1125 P->mutateType(Ty0); 1126 cast<Instruction>(T)->setOperand(0, P); 1127 } 1128 } 1129 1130 return true; 1131 } 1132 1133 bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In, 1134 ValueSeq &Cycle) { 1135 // Out = ..., In, ... 1136 if (Out == In) 1137 return true; 1138 1139 auto *BB = cast<Instruction>(Out)->getParent(); 1140 bool HadPhi = false; 1141 1142 for (auto U : Out->users()) { 1143 auto *I = dyn_cast<Instruction>(&*U); 1144 if (I == nullptr || I->getParent() != BB) 1145 continue; 1146 // Make sure that there are no multi-iteration cycles, e.g. 1147 // p1 = phi(p2) 1148 // p2 = phi(p1) 1149 // The cycle p1->p2->p1 would span two loop iterations. 1150 // Check that there is only one phi in the cycle. 1151 bool IsPhi = isa<PHINode>(I); 1152 if (IsPhi && HadPhi) 1153 return false; 1154 HadPhi |= IsPhi; 1155 if (Cycle.count(I)) 1156 return false; 1157 Cycle.insert(I); 1158 if (findCycle(I, In, Cycle)) 1159 break; 1160 Cycle.remove(I); 1161 } 1162 return !Cycle.empty(); 1163 } 1164 1165 void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI, 1166 ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) { 1167 // All the values in the cycle that are between the phi node and the 1168 // divider instruction will be classified as "early", all other values 1169 // will be "late". 1170 1171 bool IsE = true; 1172 unsigned I, N = Cycle.size(); 1173 for (I = 0; I < N; ++I) { 1174 Value *V = Cycle[I]; 1175 if (DivI == V) 1176 IsE = false; 1177 else if (!isa<PHINode>(V)) 1178 continue; 1179 // Stop if found either. 1180 break; 1181 } 1182 // "I" is the index of either DivI or the phi node, whichever was first. 1183 // "E" is "false" or "true" respectively. 1184 ValueSeq &First = !IsE ? Early : Late; 1185 for (unsigned J = 0; J < I; ++J) 1186 First.insert(Cycle[J]); 1187 1188 ValueSeq &Second = IsE ? Early : Late; 1189 Second.insert(Cycle[I]); 1190 for (++I; I < N; ++I) { 1191 Value *V = Cycle[I]; 1192 if (DivI == V || isa<PHINode>(V)) 1193 break; 1194 Second.insert(V); 1195 } 1196 1197 for (; I < N; ++I) 1198 First.insert(Cycle[I]); 1199 } 1200 1201 bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI, 1202 ValueSeq &Early, ValueSeq &Late) { 1203 // Select is an exception, since the condition value does not have to be 1204 // classified in the same way as the true/false values. The true/false 1205 // values do have to be both early or both late. 1206 if (UseI->getOpcode() == Instruction::Select) { 1207 Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2); 1208 if (Early.count(TV) || Early.count(FV)) { 1209 if (Late.count(TV) || Late.count(FV)) 1210 return false; 1211 Early.insert(UseI); 1212 } else if (Late.count(TV) || Late.count(FV)) { 1213 if (Early.count(TV) || Early.count(FV)) 1214 return false; 1215 Late.insert(UseI); 1216 } 1217 return true; 1218 } 1219 1220 // Not sure what would be the example of this, but the code below relies 1221 // on having at least one operand. 1222 if (UseI->getNumOperands() == 0) 1223 return true; 1224 1225 bool AE = true, AL = true; 1226 for (auto &I : UseI->operands()) { 1227 if (Early.count(&*I)) 1228 AL = false; 1229 else if (Late.count(&*I)) 1230 AE = false; 1231 } 1232 // If the operands appear "all early" and "all late" at the same time, 1233 // then it means that none of them are actually classified as either. 1234 // This is harmless. 1235 if (AE && AL) 1236 return true; 1237 // Conversely, if they are neither "all early" nor "all late", then 1238 // we have a mixture of early and late operands that is not a known 1239 // exception. 1240 if (!AE && !AL) 1241 return false; 1242 1243 // Check that we have covered the two special cases. 1244 assert(AE != AL); 1245 1246 if (AE) 1247 Early.insert(UseI); 1248 else 1249 Late.insert(UseI); 1250 return true; 1251 } 1252 1253 bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) { 1254 switch (I->getOpcode()) { 1255 case Instruction::And: 1256 case Instruction::Or: 1257 case Instruction::Xor: 1258 case Instruction::LShr: 1259 case Instruction::Shl: 1260 case Instruction::Select: 1261 case Instruction::ICmp: 1262 case Instruction::PHI: 1263 break; 1264 default: 1265 return false; 1266 } 1267 return true; 1268 } 1269 1270 bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V, 1271 unsigned IterCount) { 1272 auto *T = dyn_cast<IntegerType>(V->getType()); 1273 if (!T) 1274 return false; 1275 1276 KnownBits Known(T->getBitWidth()); 1277 computeKnownBits(V, Known, DL); 1278 return Known.countMinLeadingZeros() >= IterCount; 1279 } 1280 1281 bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V, 1282 unsigned IterCount) { 1283 // Assume that all inputs to the value have the high bits zero. 1284 // Check if the value itself preserves the zeros in the high bits. 1285 if (auto *C = dyn_cast<ConstantInt>(V)) 1286 return C->getValue().countLeadingZeros() >= IterCount; 1287 1288 if (auto *I = dyn_cast<Instruction>(V)) { 1289 switch (I->getOpcode()) { 1290 case Instruction::And: 1291 case Instruction::Or: 1292 case Instruction::Xor: 1293 case Instruction::LShr: 1294 case Instruction::Select: 1295 case Instruction::ICmp: 1296 case Instruction::PHI: 1297 case Instruction::ZExt: 1298 return true; 1299 } 1300 } 1301 1302 return false; 1303 } 1304 1305 bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) { 1306 unsigned Opc = I->getOpcode(); 1307 if (Opc == Instruction::Shl || Opc == Instruction::LShr) 1308 return Op != I->getOperand(1); 1309 return true; 1310 } 1311 1312 bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB, 1313 BasicBlock *ExitB, unsigned IterCount) { 1314 Value *CIV = getCountIV(LoopB); 1315 if (CIV == nullptr) 1316 return false; 1317 auto *CIVTy = dyn_cast<IntegerType>(CIV->getType()); 1318 if (CIVTy == nullptr) 1319 return false; 1320 1321 ValueSeq RShifts; 1322 ValueSeq Early, Late, Cycled; 1323 1324 // Find all value cycles that contain logical right shifts by 1. 1325 for (Instruction &I : *LoopB) { 1326 using namespace PatternMatch; 1327 1328 Value *V = nullptr; 1329 if (!match(&I, m_LShr(m_Value(V), m_One()))) 1330 continue; 1331 ValueSeq C; 1332 if (!findCycle(&I, V, C)) 1333 continue; 1334 1335 // Found a cycle. 1336 C.insert(&I); 1337 classifyCycle(&I, C, Early, Late); 1338 Cycled.insert(C.begin(), C.end()); 1339 RShifts.insert(&I); 1340 } 1341 1342 // Find the set of all values affected by the shift cycles, i.e. all 1343 // cycled values, and (recursively) all their users. 1344 ValueSeq Users(Cycled.begin(), Cycled.end()); 1345 for (unsigned i = 0; i < Users.size(); ++i) { 1346 Value *V = Users[i]; 1347 if (!isa<IntegerType>(V->getType())) 1348 return false; 1349 auto *R = cast<Instruction>(V); 1350 // If the instruction does not commute with shifts, the loop cannot 1351 // be unshifted. 1352 if (!commutesWithShift(R)) 1353 return false; 1354 for (auto I = R->user_begin(), E = R->user_end(); I != E; ++I) { 1355 auto *T = cast<Instruction>(*I); 1356 // Skip users from outside of the loop. They will be handled later. 1357 // Also, skip the right-shifts and phi nodes, since they mix early 1358 // and late values. 1359 if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T)) 1360 continue; 1361 1362 Users.insert(T); 1363 if (!classifyInst(T, Early, Late)) 1364 return false; 1365 } 1366 } 1367 1368 if (Users.empty()) 1369 return false; 1370 1371 // Verify that high bits remain zero. 1372 ValueSeq Internal(Users.begin(), Users.end()); 1373 ValueSeq Inputs; 1374 for (unsigned i = 0; i < Internal.size(); ++i) { 1375 auto *R = dyn_cast<Instruction>(Internal[i]); 1376 if (!R) 1377 continue; 1378 for (Value *Op : R->operands()) { 1379 auto *T = dyn_cast<Instruction>(Op); 1380 if (T && T->getParent() != LoopB) 1381 Inputs.insert(Op); 1382 else 1383 Internal.insert(Op); 1384 } 1385 } 1386 for (Value *V : Inputs) 1387 if (!highBitsAreZero(V, IterCount)) 1388 return false; 1389 for (Value *V : Internal) 1390 if (!keepsHighBitsZero(V, IterCount)) 1391 return false; 1392 1393 // Finally, the work can be done. Unshift each user. 1394 IRBuilder<> IRB(LoopB); 1395 std::map<Value*,Value*> ShiftMap; 1396 1397 using CastMapType = std::map<std::pair<Value *, Type *>, Value *>; 1398 1399 CastMapType CastMap; 1400 1401 auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V, 1402 IntegerType *Ty) -> Value* { 1403 auto H = CM.find(std::make_pair(V, Ty)); 1404 if (H != CM.end()) 1405 return H->second; 1406 Value *CV = IRB.CreateIntCast(V, Ty, false); 1407 CM.insert(std::make_pair(std::make_pair(V, Ty), CV)); 1408 return CV; 1409 }; 1410 1411 for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) { 1412 using namespace PatternMatch; 1413 1414 if (isa<PHINode>(I) || !Users.count(&*I)) 1415 continue; 1416 1417 // Match lshr x, 1. 1418 Value *V = nullptr; 1419 if (match(&*I, m_LShr(m_Value(V), m_One()))) { 1420 replaceAllUsesOfWithIn(&*I, V, LoopB); 1421 continue; 1422 } 1423 // For each non-cycled operand, replace it with the corresponding 1424 // value shifted left. 1425 for (auto &J : I->operands()) { 1426 Value *Op = J.get(); 1427 if (!isOperandShifted(&*I, Op)) 1428 continue; 1429 if (Users.count(Op)) 1430 continue; 1431 // Skip shifting zeros. 1432 if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) 1433 continue; 1434 // Check if we have already generated a shift for this value. 1435 auto F = ShiftMap.find(Op); 1436 Value *W = (F != ShiftMap.end()) ? F->second : nullptr; 1437 if (W == nullptr) { 1438 IRB.SetInsertPoint(&*I); 1439 // First, the shift amount will be CIV or CIV+1, depending on 1440 // whether the value is early or late. Instead of creating CIV+1, 1441 // do a single shift of the value. 1442 Value *ShAmt = CIV, *ShVal = Op; 1443 auto *VTy = cast<IntegerType>(ShVal->getType()); 1444 auto *ATy = cast<IntegerType>(ShAmt->getType()); 1445 if (Late.count(&*I)) 1446 ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1)); 1447 // Second, the types of the shifted value and the shift amount 1448 // must match. 1449 if (VTy != ATy) { 1450 if (VTy->getBitWidth() < ATy->getBitWidth()) 1451 ShVal = upcast(CastMap, IRB, ShVal, ATy); 1452 else 1453 ShAmt = upcast(CastMap, IRB, ShAmt, VTy); 1454 } 1455 // Ready to generate the shift and memoize it. 1456 W = IRB.CreateShl(ShVal, ShAmt); 1457 ShiftMap.insert(std::make_pair(Op, W)); 1458 } 1459 I->replaceUsesOfWith(Op, W); 1460 } 1461 } 1462 1463 // Update the users outside of the loop to account for having left 1464 // shifts. They would normally be shifted right in the loop, so shift 1465 // them right after the loop exit. 1466 // Take advantage of the loop-closed SSA form, which has all the post- 1467 // loop values in phi nodes. 1468 IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt()); 1469 for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) { 1470 if (!isa<PHINode>(P)) 1471 break; 1472 auto *PN = cast<PHINode>(P); 1473 Value *U = PN->getIncomingValueForBlock(LoopB); 1474 if (!Users.count(U)) 1475 continue; 1476 Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount)); 1477 PN->replaceAllUsesWith(S); 1478 // The above RAUW will create 1479 // S = lshr S, IterCount 1480 // so we need to fix it back into 1481 // S = lshr PN, IterCount 1482 cast<User>(S)->replaceUsesOfWith(S, PN); 1483 } 1484 1485 return true; 1486 } 1487 1488 void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) { 1489 for (auto &I : *LoopB) 1490 if (Value *SV = SimplifyInstruction(&I, {DL, &TLI, &DT})) 1491 I.replaceAllUsesWith(SV); 1492 1493 for (auto I = LoopB->begin(), N = I; I != LoopB->end(); I = N) { 1494 N = std::next(I); 1495 RecursivelyDeleteTriviallyDeadInstructions(&*I, &TLI); 1496 } 1497 } 1498 1499 unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) { 1500 // Arrays of coefficients of Q and the inverse, C. 1501 // Q[i] = coefficient at x^i. 1502 std::array<char,32> Q, C; 1503 1504 for (unsigned i = 0; i < 32; ++i) { 1505 Q[i] = QP & 1; 1506 QP >>= 1; 1507 } 1508 assert(Q[0] == 1); 1509 1510 // Find C, such that 1511 // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1 1512 // 1513 // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the 1514 // operations * and + are & and ^ respectively. 1515 // 1516 // Find C[i] recursively, by comparing i-th coefficient in the product 1517 // with 0 (or 1 for i=0). 1518 // 1519 // C[0] = 1, since C[0] = Q[0], and Q[0] = 1. 1520 C[0] = 1; 1521 for (unsigned i = 1; i < 32; ++i) { 1522 // Solve for C[i] in: 1523 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0 1524 // This is equivalent to 1525 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0 1526 // which is 1527 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i] 1528 unsigned T = 0; 1529 for (unsigned j = 0; j < i; ++j) 1530 T = T ^ (C[j] & Q[i-j]); 1531 C[i] = T; 1532 } 1533 1534 unsigned QV = 0; 1535 for (unsigned i = 0; i < 32; ++i) 1536 if (C[i]) 1537 QV |= (1 << i); 1538 1539 return QV; 1540 } 1541 1542 Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At, 1543 ParsedValues &PV) { 1544 IRBuilder<> B(&*At); 1545 Module *M = At->getParent()->getParent()->getParent(); 1546 Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw); 1547 1548 Value *P = PV.P, *Q = PV.Q, *P0 = P; 1549 unsigned IC = PV.IterCount; 1550 1551 if (PV.M != nullptr) 1552 P0 = P = B.CreateXor(P, PV.M); 1553 1554 // Create a bit mask to clear the high bits beyond IterCount. 1555 auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC)); 1556 1557 if (PV.IterCount != 32) 1558 P = B.CreateAnd(P, BMI); 1559 1560 if (PV.Inv) { 1561 auto *QI = dyn_cast<ConstantInt>(PV.Q); 1562 assert(QI && QI->getBitWidth() <= 32); 1563 1564 // Again, clearing bits beyond IterCount. 1565 unsigned M = (1 << PV.IterCount) - 1; 1566 unsigned Tmp = (QI->getZExtValue() | 1) & M; 1567 unsigned QV = getInverseMxN(Tmp) & M; 1568 auto *QVI = ConstantInt::get(QI->getType(), QV); 1569 P = B.CreateCall(PMF, {P, QVI}); 1570 P = B.CreateTrunc(P, QI->getType()); 1571 if (IC != 32) 1572 P = B.CreateAnd(P, BMI); 1573 } 1574 1575 Value *R = B.CreateCall(PMF, {P, Q}); 1576 1577 if (PV.M != nullptr) 1578 R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false)); 1579 1580 return R; 1581 } 1582 1583 static bool hasZeroSignBit(const Value *V) { 1584 if (const auto *CI = dyn_cast<const ConstantInt>(V)) 1585 return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0; 1586 const Instruction *I = dyn_cast<const Instruction>(V); 1587 if (!I) 1588 return false; 1589 switch (I->getOpcode()) { 1590 case Instruction::LShr: 1591 if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1))) 1592 return SI->getZExtValue() > 0; 1593 return false; 1594 case Instruction::Or: 1595 case Instruction::Xor: 1596 return hasZeroSignBit(I->getOperand(0)) && 1597 hasZeroSignBit(I->getOperand(1)); 1598 case Instruction::And: 1599 return hasZeroSignBit(I->getOperand(0)) || 1600 hasZeroSignBit(I->getOperand(1)); 1601 } 1602 return false; 1603 } 1604 1605 void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) { 1606 S.addRule("sink-zext", 1607 // Sink zext past bitwise operations. 1608 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1609 if (I->getOpcode() != Instruction::ZExt) 1610 return nullptr; 1611 Instruction *T = dyn_cast<Instruction>(I->getOperand(0)); 1612 if (!T) 1613 return nullptr; 1614 switch (T->getOpcode()) { 1615 case Instruction::And: 1616 case Instruction::Or: 1617 case Instruction::Xor: 1618 break; 1619 default: 1620 return nullptr; 1621 } 1622 IRBuilder<> B(Ctx); 1623 return B.CreateBinOp(cast<BinaryOperator>(T)->getOpcode(), 1624 B.CreateZExt(T->getOperand(0), I->getType()), 1625 B.CreateZExt(T->getOperand(1), I->getType())); 1626 }); 1627 S.addRule("xor/and -> and/xor", 1628 // (xor (and x a) (and y a)) -> (and (xor x y) a) 1629 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1630 if (I->getOpcode() != Instruction::Xor) 1631 return nullptr; 1632 Instruction *And0 = dyn_cast<Instruction>(I->getOperand(0)); 1633 Instruction *And1 = dyn_cast<Instruction>(I->getOperand(1)); 1634 if (!And0 || !And1) 1635 return nullptr; 1636 if (And0->getOpcode() != Instruction::And || 1637 And1->getOpcode() != Instruction::And) 1638 return nullptr; 1639 if (And0->getOperand(1) != And1->getOperand(1)) 1640 return nullptr; 1641 IRBuilder<> B(Ctx); 1642 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1->getOperand(0)), 1643 And0->getOperand(1)); 1644 }); 1645 S.addRule("sink binop into select", 1646 // (Op (select c x y) z) -> (select c (Op x z) (Op y z)) 1647 // (Op x (select c y z)) -> (select c (Op x y) (Op x z)) 1648 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1649 BinaryOperator *BO = dyn_cast<BinaryOperator>(I); 1650 if (!BO) 1651 return nullptr; 1652 Instruction::BinaryOps Op = BO->getOpcode(); 1653 if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(0))) { 1654 IRBuilder<> B(Ctx); 1655 Value *X = Sel->getTrueValue(), *Y = Sel->getFalseValue(); 1656 Value *Z = BO->getOperand(1); 1657 return B.CreateSelect(Sel->getCondition(), 1658 B.CreateBinOp(Op, X, Z), 1659 B.CreateBinOp(Op, Y, Z)); 1660 } 1661 if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(1))) { 1662 IRBuilder<> B(Ctx); 1663 Value *X = BO->getOperand(0); 1664 Value *Y = Sel->getTrueValue(), *Z = Sel->getFalseValue(); 1665 return B.CreateSelect(Sel->getCondition(), 1666 B.CreateBinOp(Op, X, Y), 1667 B.CreateBinOp(Op, X, Z)); 1668 } 1669 return nullptr; 1670 }); 1671 S.addRule("fold select-select", 1672 // (select c (select c x y) z) -> (select c x z) 1673 // (select c x (select c y z)) -> (select c x z) 1674 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1675 SelectInst *Sel = dyn_cast<SelectInst>(I); 1676 if (!Sel) 1677 return nullptr; 1678 IRBuilder<> B(Ctx); 1679 Value *C = Sel->getCondition(); 1680 if (SelectInst *Sel0 = dyn_cast<SelectInst>(Sel->getTrueValue())) { 1681 if (Sel0->getCondition() == C) 1682 return B.CreateSelect(C, Sel0->getTrueValue(), Sel->getFalseValue()); 1683 } 1684 if (SelectInst *Sel1 = dyn_cast<SelectInst>(Sel->getFalseValue())) { 1685 if (Sel1->getCondition() == C) 1686 return B.CreateSelect(C, Sel->getTrueValue(), Sel1->getFalseValue()); 1687 } 1688 return nullptr; 1689 }); 1690 S.addRule("or-signbit -> xor-signbit", 1691 // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0) 1692 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1693 if (I->getOpcode() != Instruction::Or) 1694 return nullptr; 1695 ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1)); 1696 if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit()) 1697 return nullptr; 1698 if (!hasZeroSignBit(I->getOperand(0))) 1699 return nullptr; 1700 return IRBuilder<>(Ctx).CreateXor(I->getOperand(0), Msb); 1701 }); 1702 S.addRule("sink lshr into binop", 1703 // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c)) 1704 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1705 if (I->getOpcode() != Instruction::LShr) 1706 return nullptr; 1707 BinaryOperator *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0)); 1708 if (!BitOp) 1709 return nullptr; 1710 switch (BitOp->getOpcode()) { 1711 case Instruction::And: 1712 case Instruction::Or: 1713 case Instruction::Xor: 1714 break; 1715 default: 1716 return nullptr; 1717 } 1718 IRBuilder<> B(Ctx); 1719 Value *S = I->getOperand(1); 1720 return B.CreateBinOp(BitOp->getOpcode(), 1721 B.CreateLShr(BitOp->getOperand(0), S), 1722 B.CreateLShr(BitOp->getOperand(1), S)); 1723 }); 1724 S.addRule("expose bitop-const", 1725 // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b)) 1726 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1727 auto IsBitOp = [](unsigned Op) -> bool { 1728 switch (Op) { 1729 case Instruction::And: 1730 case Instruction::Or: 1731 case Instruction::Xor: 1732 return true; 1733 } 1734 return false; 1735 }; 1736 BinaryOperator *BitOp1 = dyn_cast<BinaryOperator>(I); 1737 if (!BitOp1 || !IsBitOp(BitOp1->getOpcode())) 1738 return nullptr; 1739 BinaryOperator *BitOp2 = dyn_cast<BinaryOperator>(BitOp1->getOperand(0)); 1740 if (!BitOp2 || !IsBitOp(BitOp2->getOpcode())) 1741 return nullptr; 1742 ConstantInt *CA = dyn_cast<ConstantInt>(BitOp2->getOperand(1)); 1743 ConstantInt *CB = dyn_cast<ConstantInt>(BitOp1->getOperand(1)); 1744 if (!CA || !CB) 1745 return nullptr; 1746 IRBuilder<> B(Ctx); 1747 Value *X = BitOp2->getOperand(0); 1748 return B.CreateBinOp(BitOp2->getOpcode(), X, 1749 B.CreateBinOp(BitOp1->getOpcode(), CA, CB)); 1750 }); 1751 } 1752 1753 void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) { 1754 S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a", 1755 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1756 if (I->getOpcode() != Instruction::And) 1757 return nullptr; 1758 Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0)); 1759 ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1)); 1760 if (!Xor || !C0) 1761 return nullptr; 1762 if (Xor->getOpcode() != Instruction::Xor) 1763 return nullptr; 1764 Instruction *And0 = dyn_cast<Instruction>(Xor->getOperand(0)); 1765 Instruction *And1 = dyn_cast<Instruction>(Xor->getOperand(1)); 1766 // Pick the first non-null and. 1767 if (!And0 || And0->getOpcode() != Instruction::And) 1768 std::swap(And0, And1); 1769 ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1)); 1770 if (!C1) 1771 return nullptr; 1772 uint32_t V0 = C0->getZExtValue(); 1773 uint32_t V1 = C1->getZExtValue(); 1774 if (V0 != (V0 & V1)) 1775 return nullptr; 1776 IRBuilder<> B(Ctx); 1777 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1), C0); 1778 }); 1779 } 1780 1781 bool PolynomialMultiplyRecognize::recognize() { 1782 LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n" 1783 << *CurLoop << '\n'); 1784 // Restrictions: 1785 // - The loop must consist of a single block. 1786 // - The iteration count must be known at compile-time. 1787 // - The loop must have an induction variable starting from 0, and 1788 // incremented in each iteration of the loop. 1789 BasicBlock *LoopB = CurLoop->getHeader(); 1790 LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB); 1791 1792 if (LoopB != CurLoop->getLoopLatch()) 1793 return false; 1794 BasicBlock *ExitB = CurLoop->getExitBlock(); 1795 if (ExitB == nullptr) 1796 return false; 1797 BasicBlock *EntryB = CurLoop->getLoopPreheader(); 1798 if (EntryB == nullptr) 1799 return false; 1800 1801 unsigned IterCount = 0; 1802 const SCEV *CT = SE.getBackedgeTakenCount(CurLoop); 1803 if (isa<SCEVCouldNotCompute>(CT)) 1804 return false; 1805 if (auto *CV = dyn_cast<SCEVConstant>(CT)) 1806 IterCount = CV->getValue()->getZExtValue() + 1; 1807 1808 Value *CIV = getCountIV(LoopB); 1809 ParsedValues PV; 1810 Simplifier PreSimp; 1811 PV.IterCount = IterCount; 1812 LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount 1813 << '\n'); 1814 1815 setupPreSimplifier(PreSimp); 1816 1817 // Perform a preliminary scan of select instructions to see if any of them 1818 // looks like a generator of the polynomial multiply steps. Assume that a 1819 // loop can only contain a single transformable operation, so stop the 1820 // traversal after the first reasonable candidate was found. 1821 // XXX: Currently this approach can modify the loop before being 100% sure 1822 // that the transformation can be carried out. 1823 bool FoundPreScan = false; 1824 auto FeedsPHI = [LoopB](const Value *V) -> bool { 1825 for (const Value *U : V->users()) { 1826 if (const auto *P = dyn_cast<const PHINode>(U)) 1827 if (P->getParent() == LoopB) 1828 return true; 1829 } 1830 return false; 1831 }; 1832 for (Instruction &In : *LoopB) { 1833 SelectInst *SI = dyn_cast<SelectInst>(&In); 1834 if (!SI || !FeedsPHI(SI)) 1835 continue; 1836 1837 Simplifier::Context C(SI); 1838 Value *T = PreSimp.simplify(C); 1839 SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : SI; 1840 LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n'); 1841 if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) { 1842 FoundPreScan = true; 1843 if (SelI != SI) { 1844 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1845 SI->replaceAllUsesWith(NewSel); 1846 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1847 } 1848 break; 1849 } 1850 } 1851 1852 if (!FoundPreScan) { 1853 LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n"); 1854 return false; 1855 } 1856 1857 if (!PV.Left) { 1858 // The right shift version actually only returns the higher bits of 1859 // the result (each iteration discards the LSB). If we want to convert it 1860 // to a left-shifting loop, the working data type must be at least as 1861 // wide as the target's pmpy instruction. 1862 if (!promoteTypes(LoopB, ExitB)) 1863 return false; 1864 // Run post-promotion simplifications. 1865 Simplifier PostSimp; 1866 setupPostSimplifier(PostSimp); 1867 for (Instruction &In : *LoopB) { 1868 SelectInst *SI = dyn_cast<SelectInst>(&In); 1869 if (!SI || !FeedsPHI(SI)) 1870 continue; 1871 Simplifier::Context C(SI); 1872 Value *T = PostSimp.simplify(C); 1873 SelectInst *SelI = dyn_cast_or_null<SelectInst>(T); 1874 if (SelI != SI) { 1875 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1876 SI->replaceAllUsesWith(NewSel); 1877 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1878 } 1879 break; 1880 } 1881 1882 if (!convertShiftsToLeft(LoopB, ExitB, IterCount)) 1883 return false; 1884 cleanupLoopBody(LoopB); 1885 } 1886 1887 // Scan the loop again, find the generating select instruction. 1888 bool FoundScan = false; 1889 for (Instruction &In : *LoopB) { 1890 SelectInst *SelI = dyn_cast<SelectInst>(&In); 1891 if (!SelI) 1892 continue; 1893 LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n'); 1894 FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false); 1895 if (FoundScan) 1896 break; 1897 } 1898 assert(FoundScan); 1899 1900 LLVM_DEBUG({ 1901 StringRef PP = (PV.M ? "(P+M)" : "P"); 1902 if (!PV.Inv) 1903 dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n"; 1904 else 1905 dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + " 1906 << PP << "\n"; 1907 dbgs() << " Res:" << *PV.Res << "\n P:" << *PV.P << "\n"; 1908 if (PV.M) 1909 dbgs() << " M:" << *PV.M << "\n"; 1910 dbgs() << " Q:" << *PV.Q << "\n"; 1911 dbgs() << " Iteration count:" << PV.IterCount << "\n"; 1912 }); 1913 1914 BasicBlock::iterator At(EntryB->getTerminator()); 1915 Value *PM = generate(At, PV); 1916 if (PM == nullptr) 1917 return false; 1918 1919 if (PM->getType() != PV.Res->getType()) 1920 PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false); 1921 1922 PV.Res->replaceAllUsesWith(PM); 1923 PV.Res->eraseFromParent(); 1924 return true; 1925 } 1926 1927 int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) { 1928 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1))) 1929 return SC->getAPInt().getSExtValue(); 1930 return 0; 1931 } 1932 1933 bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) { 1934 // Allow volatile stores if HexagonVolatileMemcpy is enabled. 1935 if (!(SI->isVolatile() && HexagonVolatileMemcpy) && !SI->isSimple()) 1936 return false; 1937 1938 Value *StoredVal = SI->getValueOperand(); 1939 Value *StorePtr = SI->getPointerOperand(); 1940 1941 // Reject stores that are so large that they overflow an unsigned. 1942 uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); 1943 if ((SizeInBits & 7) || (SizeInBits >> 32) != 0) 1944 return false; 1945 1946 // See if the pointer expression is an AddRec like {base,+,1} on the current 1947 // loop, which indicates a strided store. If we have something else, it's a 1948 // random store we can't handle. 1949 auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 1950 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine()) 1951 return false; 1952 1953 // Check to see if the stride matches the size of the store. If so, then we 1954 // know that every byte is touched in the loop. 1955 int Stride = getSCEVStride(StoreEv); 1956 if (Stride == 0) 1957 return false; 1958 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 1959 if (StoreSize != unsigned(std::abs(Stride))) 1960 return false; 1961 1962 // The store must be feeding a non-volatile load. 1963 LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); 1964 if (!LI || !LI->isSimple()) 1965 return false; 1966 1967 // See if the pointer expression is an AddRec like {base,+,1} on the current 1968 // loop, which indicates a strided load. If we have something else, it's a 1969 // random load we can't handle. 1970 Value *LoadPtr = LI->getPointerOperand(); 1971 auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr)); 1972 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine()) 1973 return false; 1974 1975 // The store and load must share the same stride. 1976 if (StoreEv->getOperand(1) != LoadEv->getOperand(1)) 1977 return false; 1978 1979 // Success. This store can be converted into a memcpy. 1980 return true; 1981 } 1982 1983 /// mayLoopAccessLocation - Return true if the specified loop might access the 1984 /// specified pointer location, which is a loop-strided access. The 'Access' 1985 /// argument specifies what the verboten forms of access are (read or write). 1986 static bool 1987 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, 1988 const SCEV *BECount, unsigned StoreSize, 1989 AliasAnalysis &AA, 1990 SmallPtrSetImpl<Instruction *> &Ignored) { 1991 // Get the location that may be stored across the loop. Since the access 1992 // is strided positively through memory, we say that the modified location 1993 // starts at the pointer and has infinite size. 1994 LocationSize AccessSize = LocationSize::afterPointer(); 1995 1996 // If the loop iterates a fixed number of times, we can refine the access 1997 // size to be exactly the size of the memset, which is (BECount+1)*StoreSize 1998 if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) 1999 AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * 2000 StoreSize); 2001 2002 // TODO: For this to be really effective, we have to dive into the pointer 2003 // operand in the store. Store to &A[i] of 100 will always return may alias 2004 // with store of &A[100], we need to StoreLoc to be "A" with size of 100, 2005 // which will then no-alias a store to &A[100]. 2006 MemoryLocation StoreLoc(Ptr, AccessSize); 2007 2008 for (auto *B : L->blocks()) 2009 for (auto &I : *B) 2010 if (Ignored.count(&I) == 0 && 2011 isModOrRefSet( 2012 intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) 2013 return true; 2014 2015 return false; 2016 } 2017 2018 void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB, 2019 SmallVectorImpl<StoreInst*> &Stores) { 2020 Stores.clear(); 2021 for (Instruction &I : *BB) 2022 if (StoreInst *SI = dyn_cast<StoreInst>(&I)) 2023 if (isLegalStore(CurLoop, SI)) 2024 Stores.push_back(SI); 2025 } 2026 2027 bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop, 2028 StoreInst *SI, const SCEV *BECount) { 2029 assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) && 2030 "Expected only non-volatile stores, or Hexagon-specific memcpy" 2031 "to volatile destination."); 2032 2033 Value *StorePtr = SI->getPointerOperand(); 2034 auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 2035 unsigned Stride = getSCEVStride(StoreEv); 2036 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 2037 if (Stride != StoreSize) 2038 return false; 2039 2040 // See if the pointer expression is an AddRec like {base,+,1} on the current 2041 // loop, which indicates a strided load. If we have something else, it's a 2042 // random load we can't handle. 2043 auto *LI = cast<LoadInst>(SI->getValueOperand()); 2044 auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); 2045 2046 // The trip count of the loop and the base pointer of the addrec SCEV is 2047 // guaranteed to be loop invariant, which means that it should dominate the 2048 // header. This allows us to insert code for it in the preheader. 2049 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 2050 Instruction *ExpPt = Preheader->getTerminator(); 2051 IRBuilder<> Builder(ExpPt); 2052 SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom"); 2053 2054 Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace()); 2055 2056 // Okay, we have a strided store "p[i]" of a loaded value. We can turn 2057 // this into a memcpy/memmove in the loop preheader now if we want. However, 2058 // this would be unsafe to do if there is anything else in the loop that may 2059 // read or write the memory region we're storing to. For memcpy, this 2060 // includes the load that feeds the stores. Check for an alias by generating 2061 // the base address and checking everything. 2062 Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(), 2063 Builder.getInt8PtrTy(SI->getPointerAddressSpace()), ExpPt); 2064 Value *LoadBasePtr = nullptr; 2065 2066 bool Overlap = false; 2067 bool DestVolatile = SI->isVolatile(); 2068 Type *BECountTy = BECount->getType(); 2069 2070 if (DestVolatile) { 2071 // The trip count must fit in i32, since it is the type of the "num_words" 2072 // argument to hexagon_memcpy_forward_vp4cp4n2. 2073 if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) { 2074 CleanupAndExit: 2075 // If we generated new code for the base pointer, clean up. 2076 Expander.clear(); 2077 if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) { 2078 RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); 2079 StoreBasePtr = nullptr; 2080 } 2081 if (LoadBasePtr) { 2082 RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); 2083 LoadBasePtr = nullptr; 2084 } 2085 return false; 2086 } 2087 } 2088 2089 SmallPtrSet<Instruction*, 2> Ignore1; 2090 Ignore1.insert(SI); 2091 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, 2092 StoreSize, *AA, Ignore1)) { 2093 // Check if the load is the offending instruction. 2094 Ignore1.insert(LI); 2095 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, 2096 BECount, StoreSize, *AA, Ignore1)) { 2097 // Still bad. Nothing we can do. 2098 goto CleanupAndExit; 2099 } 2100 // It worked with the load ignored. 2101 Overlap = true; 2102 } 2103 2104 if (!Overlap) { 2105 if (DisableMemcpyIdiom || !HasMemcpy) 2106 goto CleanupAndExit; 2107 } else { 2108 // Don't generate memmove if this function will be inlined. This is 2109 // because the caller will undergo this transformation after inlining. 2110 Function *Func = CurLoop->getHeader()->getParent(); 2111 if (Func->hasFnAttribute(Attribute::AlwaysInline)) 2112 goto CleanupAndExit; 2113 2114 // In case of a memmove, the call to memmove will be executed instead 2115 // of the loop, so we need to make sure that there is nothing else in 2116 // the loop than the load, store and instructions that these two depend 2117 // on. 2118 SmallVector<Instruction*,2> Insts; 2119 Insts.push_back(SI); 2120 Insts.push_back(LI); 2121 if (!coverLoop(CurLoop, Insts)) 2122 goto CleanupAndExit; 2123 2124 if (DisableMemmoveIdiom || !HasMemmove) 2125 goto CleanupAndExit; 2126 bool IsNested = CurLoop->getParentLoop() != nullptr; 2127 if (IsNested && OnlyNonNestedMemmove) 2128 goto CleanupAndExit; 2129 } 2130 2131 // For a memcpy, we have to make sure that the input array is not being 2132 // mutated by the loop. 2133 LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(), 2134 Builder.getInt8PtrTy(LI->getPointerAddressSpace()), ExpPt); 2135 2136 SmallPtrSet<Instruction*, 2> Ignore2; 2137 Ignore2.insert(SI); 2138 if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, 2139 StoreSize, *AA, Ignore2)) 2140 goto CleanupAndExit; 2141 2142 // Check the stride. 2143 bool StridePos = getSCEVStride(LoadEv) >= 0; 2144 2145 // Currently, the volatile memcpy only emulates traversing memory forward. 2146 if (!StridePos && DestVolatile) 2147 goto CleanupAndExit; 2148 2149 bool RuntimeCheck = (Overlap || DestVolatile); 2150 2151 BasicBlock *ExitB; 2152 if (RuntimeCheck) { 2153 // The runtime check needs a single exit block. 2154 SmallVector<BasicBlock*, 8> ExitBlocks; 2155 CurLoop->getUniqueExitBlocks(ExitBlocks); 2156 if (ExitBlocks.size() != 1) 2157 goto CleanupAndExit; 2158 ExitB = ExitBlocks[0]; 2159 } 2160 2161 // The # stored bytes is (BECount+1)*Size. Expand the trip count out to 2162 // pointer size if it isn't already. 2163 LLVMContext &Ctx = SI->getContext(); 2164 BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); 2165 DebugLoc DLoc = SI->getDebugLoc(); 2166 2167 const SCEV *NumBytesS = 2168 SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); 2169 if (StoreSize != 1) 2170 NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), 2171 SCEV::FlagNUW); 2172 Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt); 2173 if (Instruction *In = dyn_cast<Instruction>(NumBytes)) 2174 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2175 NumBytes = Simp; 2176 2177 CallInst *NewCall; 2178 2179 if (RuntimeCheck) { 2180 unsigned Threshold = RuntimeMemSizeThreshold; 2181 if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) { 2182 uint64_t C = CI->getZExtValue(); 2183 if (Threshold != 0 && C < Threshold) 2184 goto CleanupAndExit; 2185 if (C < CompileTimeMemSizeThreshold) 2186 goto CleanupAndExit; 2187 } 2188 2189 BasicBlock *Header = CurLoop->getHeader(); 2190 Function *Func = Header->getParent(); 2191 Loop *ParentL = LF->getLoopFor(Preheader); 2192 StringRef HeaderName = Header->getName(); 2193 2194 // Create a new (empty) preheader, and update the PHI nodes in the 2195 // header to use the new preheader. 2196 BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph", 2197 Func, Header); 2198 if (ParentL) 2199 ParentL->addBasicBlockToLoop(NewPreheader, *LF); 2200 IRBuilder<>(NewPreheader).CreateBr(Header); 2201 for (auto &In : *Header) { 2202 PHINode *PN = dyn_cast<PHINode>(&In); 2203 if (!PN) 2204 break; 2205 int bx = PN->getBasicBlockIndex(Preheader); 2206 if (bx >= 0) 2207 PN->setIncomingBlock(bx, NewPreheader); 2208 } 2209 DT->addNewBlock(NewPreheader, Preheader); 2210 DT->changeImmediateDominator(Header, NewPreheader); 2211 2212 // Check for safe conditions to execute memmove. 2213 // If stride is positive, copying things from higher to lower addresses 2214 // is equivalent to memmove. For negative stride, it's the other way 2215 // around. Copying forward in memory with positive stride may not be 2216 // same as memmove since we may be copying values that we just stored 2217 // in some previous iteration. 2218 Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy); 2219 Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy); 2220 Value *LowA = StridePos ? SA : LA; 2221 Value *HighA = StridePos ? LA : SA; 2222 Value *CmpA = Builder.CreateICmpULT(LowA, HighA); 2223 Value *Cond = CmpA; 2224 2225 // Check for distance between pointers. Since the case LowA < HighA 2226 // is checked for above, assume LowA >= HighA. 2227 Value *Dist = Builder.CreateSub(LowA, HighA); 2228 Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist); 2229 Value *CmpEither = Builder.CreateOr(Cond, CmpD); 2230 Cond = CmpEither; 2231 2232 if (Threshold != 0) { 2233 Type *Ty = NumBytes->getType(); 2234 Value *Thr = ConstantInt::get(Ty, Threshold); 2235 Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes); 2236 Value *CmpBoth = Builder.CreateAnd(Cond, CmpB); 2237 Cond = CmpBoth; 2238 } 2239 BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli", 2240 Func, NewPreheader); 2241 if (ParentL) 2242 ParentL->addBasicBlockToLoop(MemmoveB, *LF); 2243 Instruction *OldT = Preheader->getTerminator(); 2244 Builder.CreateCondBr(Cond, MemmoveB, NewPreheader); 2245 OldT->eraseFromParent(); 2246 Preheader->setName(Preheader->getName()+".old"); 2247 DT->addNewBlock(MemmoveB, Preheader); 2248 // Find the new immediate dominator of the exit block. 2249 BasicBlock *ExitD = Preheader; 2250 for (auto PI = pred_begin(ExitB), PE = pred_end(ExitB); PI != PE; ++PI) { 2251 BasicBlock *PB = *PI; 2252 ExitD = DT->findNearestCommonDominator(ExitD, PB); 2253 if (!ExitD) 2254 break; 2255 } 2256 // If the prior immediate dominator of ExitB was dominated by the 2257 // old preheader, then the old preheader becomes the new immediate 2258 // dominator. Otherwise don't change anything (because the newly 2259 // added blocks are dominated by the old preheader). 2260 if (ExitD && DT->dominates(Preheader, ExitD)) { 2261 DomTreeNode *BN = DT->getNode(ExitB); 2262 DomTreeNode *DN = DT->getNode(ExitD); 2263 BN->setIDom(DN); 2264 } 2265 2266 // Add a call to memmove to the conditional block. 2267 IRBuilder<> CondBuilder(MemmoveB); 2268 CondBuilder.CreateBr(ExitB); 2269 CondBuilder.SetInsertPoint(MemmoveB->getTerminator()); 2270 2271 if (DestVolatile) { 2272 Type *Int32Ty = Type::getInt32Ty(Ctx); 2273 Type *Int32PtrTy = Type::getInt32PtrTy(Ctx); 2274 Type *VoidTy = Type::getVoidTy(Ctx); 2275 Module *M = Func->getParent(); 2276 FunctionCallee Fn = M->getOrInsertFunction( 2277 HexagonVolatileMemcpyName, VoidTy, Int32PtrTy, Int32PtrTy, Int32Ty); 2278 2279 const SCEV *OneS = SE->getConstant(Int32Ty, 1); 2280 const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty); 2281 const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW); 2282 Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty, 2283 MemmoveB->getTerminator()); 2284 if (Instruction *In = dyn_cast<Instruction>(NumWords)) 2285 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2286 NumWords = Simp; 2287 2288 Value *Op0 = (StoreBasePtr->getType() == Int32PtrTy) 2289 ? StoreBasePtr 2290 : CondBuilder.CreateBitCast(StoreBasePtr, Int32PtrTy); 2291 Value *Op1 = (LoadBasePtr->getType() == Int32PtrTy) 2292 ? LoadBasePtr 2293 : CondBuilder.CreateBitCast(LoadBasePtr, Int32PtrTy); 2294 NewCall = CondBuilder.CreateCall(Fn, {Op0, Op1, NumWords}); 2295 } else { 2296 NewCall = CondBuilder.CreateMemMove( 2297 StoreBasePtr, SI->getAlign(), LoadBasePtr, LI->getAlign(), NumBytes); 2298 } 2299 } else { 2300 NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlign(), LoadBasePtr, 2301 LI->getAlign(), NumBytes); 2302 // Okay, the memcpy has been formed. Zap the original store and 2303 // anything that feeds into it. 2304 RecursivelyDeleteTriviallyDeadInstructions(SI, TLI); 2305 } 2306 2307 NewCall->setDebugLoc(DLoc); 2308 2309 LLVM_DEBUG(dbgs() << " Formed " << (Overlap ? "memmove: " : "memcpy: ") 2310 << *NewCall << "\n" 2311 << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" 2312 << " from store ptr=" << *StoreEv << " at: " << *SI 2313 << "\n"); 2314 2315 return true; 2316 } 2317 2318 // Check if the instructions in Insts, together with their dependencies 2319 // cover the loop in the sense that the loop could be safely eliminated once 2320 // the instructions in Insts are removed. 2321 bool HexagonLoopIdiomRecognize::coverLoop(Loop *L, 2322 SmallVectorImpl<Instruction*> &Insts) const { 2323 SmallSet<BasicBlock*,8> LoopBlocks; 2324 for (auto *B : L->blocks()) 2325 LoopBlocks.insert(B); 2326 2327 SetVector<Instruction*> Worklist(Insts.begin(), Insts.end()); 2328 2329 // Collect all instructions from the loop that the instructions in Insts 2330 // depend on (plus their dependencies, etc.). These instructions will 2331 // constitute the expression trees that feed those in Insts, but the trees 2332 // will be limited only to instructions contained in the loop. 2333 for (unsigned i = 0; i < Worklist.size(); ++i) { 2334 Instruction *In = Worklist[i]; 2335 for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) { 2336 Instruction *OpI = dyn_cast<Instruction>(I); 2337 if (!OpI) 2338 continue; 2339 BasicBlock *PB = OpI->getParent(); 2340 if (!LoopBlocks.count(PB)) 2341 continue; 2342 Worklist.insert(OpI); 2343 } 2344 } 2345 2346 // Scan all instructions in the loop, if any of them have a user outside 2347 // of the loop, or outside of the expressions collected above, then either 2348 // the loop has a side-effect visible outside of it, or there are 2349 // instructions in it that are not involved in the original set Insts. 2350 for (auto *B : L->blocks()) { 2351 for (auto &In : *B) { 2352 if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In)) 2353 continue; 2354 if (!Worklist.count(&In) && In.mayHaveSideEffects()) 2355 return false; 2356 for (auto K : In.users()) { 2357 Instruction *UseI = dyn_cast<Instruction>(K); 2358 if (!UseI) 2359 continue; 2360 BasicBlock *UseB = UseI->getParent(); 2361 if (LF->getLoopFor(UseB) != L) 2362 return false; 2363 } 2364 } 2365 } 2366 2367 return true; 2368 } 2369 2370 /// runOnLoopBlock - Process the specified block, which lives in a counted loop 2371 /// with the specified backedge count. This block is known to be in the current 2372 /// loop and not in any subloops. 2373 bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, 2374 const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) { 2375 // We can only promote stores in this block if they are unconditionally 2376 // executed in the loop. For a block to be unconditionally executed, it has 2377 // to dominate all the exit blocks of the loop. Verify this now. 2378 auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool { 2379 return DT->dominates(BB, EB); 2380 }; 2381 if (!all_of(ExitBlocks, DominatedByBB)) 2382 return false; 2383 2384 bool MadeChange = false; 2385 // Look for store instructions, which may be optimized to memset/memcpy. 2386 SmallVector<StoreInst*,8> Stores; 2387 collectStores(CurLoop, BB, Stores); 2388 2389 // Optimize the store into a memcpy, if it feeds an similarly strided load. 2390 for (auto &SI : Stores) 2391 MadeChange |= processCopyingStore(CurLoop, SI, BECount); 2392 2393 return MadeChange; 2394 } 2395 2396 bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) { 2397 PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE); 2398 if (PMR.recognize()) 2399 return true; 2400 2401 if (!HasMemcpy && !HasMemmove) 2402 return false; 2403 2404 const SCEV *BECount = SE->getBackedgeTakenCount(L); 2405 assert(!isa<SCEVCouldNotCompute>(BECount) && 2406 "runOnCountableLoop() called on a loop without a predictable" 2407 "backedge-taken count"); 2408 2409 SmallVector<BasicBlock *, 8> ExitBlocks; 2410 L->getUniqueExitBlocks(ExitBlocks); 2411 2412 bool Changed = false; 2413 2414 // Scan all the blocks in the loop that are not in subloops. 2415 for (auto *BB : L->getBlocks()) { 2416 // Ignore blocks in subloops. 2417 if (LF->getLoopFor(BB) != L) 2418 continue; 2419 Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks); 2420 } 2421 2422 return Changed; 2423 } 2424 2425 bool HexagonLoopIdiomRecognize::run(Loop *L) { 2426 const Module &M = *L->getHeader()->getParent()->getParent(); 2427 if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon) 2428 return false; 2429 2430 // If the loop could not be converted to canonical form, it must have an 2431 // indirectbr in it, just give up. 2432 if (!L->getLoopPreheader()) 2433 return false; 2434 2435 // Disable loop idiom recognition if the function's name is a common idiom. 2436 StringRef Name = L->getHeader()->getParent()->getName(); 2437 if (Name == "memset" || Name == "memcpy" || Name == "memmove") 2438 return false; 2439 2440 DL = &L->getHeader()->getModule()->getDataLayout(); 2441 2442 HasMemcpy = TLI->has(LibFunc_memcpy); 2443 HasMemmove = TLI->has(LibFunc_memmove); 2444 2445 if (SE->hasLoopInvariantBackedgeTakenCount(L)) 2446 return runOnCountableLoop(L); 2447 return false; 2448 } 2449 2450 bool HexagonLoopIdiomRecognizeLegacyPass::runOnLoop(Loop *L, 2451 LPPassManager &LPM) { 2452 if (skipLoop(L)) 2453 return false; 2454 2455 auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 2456 auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 2457 auto *LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2458 auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( 2459 *L->getHeader()->getParent()); 2460 auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 2461 return HexagonLoopIdiomRecognize(AA, DT, LF, TLI, SE).run(L); 2462 } 2463 2464 Pass *llvm::createHexagonLoopIdiomPass() { 2465 return new HexagonLoopIdiomRecognizeLegacyPass(); 2466 } 2467 2468 PreservedAnalyses 2469 HexagonLoopIdiomRecognitionPass::run(Loop &L, LoopAnalysisManager &AM, 2470 LoopStandardAnalysisResults &AR, 2471 LPMUpdater &U) { 2472 return HexagonLoopIdiomRecognize(&AR.AA, &AR.DT, &AR.LI, &AR.TLI, &AR.SE) 2473 .run(&L) 2474 ? getLoopPassPreservedAnalyses() 2475 : PreservedAnalyses::all(); 2476 } 2477