//===- HexagonLoopIdiomRecognition.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpander.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>

// Defined only after every #include: a header must never observe this
// file's DEBUG_TYPE, otherwise inline code in that header could expand
// differently in different translation units (an ODR hazard).
#define DEBUG_TYPE "hexagon-lir"

using namespace llvm;

// Command-line knobs for the Hexagon loop idiom recognizer. All of them are
// hidden from the default -help output (cl::Hidden).

// When set, suppresses generation of memcpy by the idiom recognizer
// (default: false).
static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memcpy in loop idiom recognition"));

// When set, suppresses generation of memmove by the idiom recognizer
// (default: false).
static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memmove in loop idiom recognition"));

// Size threshold, in bytes, for the runtime check that guards a generated
// memmove (default: 0).
static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
  "check guarding the memmove."));

// Size threshold, in bytes, applied when the transfer size is a
// compile-time constant (default: 64).
static cl::opt<unsigned> CompileTimeMemSizeThreshold(
  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
  cl::desc("Threshold (in bytes) to perform the transformation, if the "
    "runtime loop count (mem transfer size) is known at compile-time."));

// Restricts memmove generation to loops that are not nested
// (default: true).
static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
  cl::Hidden, cl::init(true),
  cl::desc("Only enable generating memmove in non-nested loops"));

// NOTE(review): the flag is spelled "disable-hexagon-volatile-memcpy", but
// both its description and the variable name describe *enabling* the
// Hexagon-specific memcpy for volatile destinations. Check the use sites to
// confirm which sense is intended. The spelling is intentionally left as is,
// because existing command lines may already pass this flag. External
// linkage is kept exactly as in the original declaration.
cl::opt<bool> HexagonVolatileMemcpy("disable-hexagon-volatile-memcpy",
  cl::Hidden, cl::init(false),
  cl::desc("Enable Hexagon-specific memcpy for volatile destination."));

// Upper bound on the number of rewrite steps the expression simplifier may
// take (default: 10000).
static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
  cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR"));

// Name of the Hexagon routine associated with the volatile-destination
// memcpy (presumably the callee emitted for such copies; its use site is
// outside this view). Both the pointer and the pointee are const so the
// name cannot be reassigned by accident.
static const char *const HexagonVolatileMemcpyName
  = "hexagon_memcpy_forward_vp4cp4n2";
106 107 namespace llvm { 108 109 void initializeHexagonLoopIdiomRecognizePass(PassRegistry&); 110 Pass *createHexagonLoopIdiomPass(); 111 112 } // end namespace llvm 113 114 namespace { 115 116 class HexagonLoopIdiomRecognize : public LoopPass { 117 public: 118 static char ID; 119 120 explicit HexagonLoopIdiomRecognize() : LoopPass(ID) { 121 initializeHexagonLoopIdiomRecognizePass(*PassRegistry::getPassRegistry()); 122 } 123 124 StringRef getPassName() const override { 125 return "Recognize Hexagon-specific loop idioms"; 126 } 127 128 void getAnalysisUsage(AnalysisUsage &AU) const override { 129 AU.addRequired<LoopInfoWrapperPass>(); 130 AU.addRequiredID(LoopSimplifyID); 131 AU.addRequiredID(LCSSAID); 132 AU.addRequired<AAResultsWrapperPass>(); 133 AU.addPreserved<AAResultsWrapperPass>(); 134 AU.addRequired<ScalarEvolutionWrapperPass>(); 135 AU.addRequired<DominatorTreeWrapperPass>(); 136 AU.addRequired<TargetLibraryInfoWrapperPass>(); 137 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 138 } 139 140 bool runOnLoop(Loop *L, LPPassManager &LPM) override; 141 142 private: 143 int getSCEVStride(const SCEVAddRecExpr *StoreEv); 144 bool isLegalStore(Loop *CurLoop, StoreInst *SI); 145 void collectStores(Loop *CurLoop, BasicBlock *BB, 146 SmallVectorImpl<StoreInst*> &Stores); 147 bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount); 148 bool coverLoop(Loop *L, SmallVectorImpl<Instruction*> &Insts) const; 149 bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount, 150 SmallVectorImpl<BasicBlock*> &ExitBlocks); 151 bool runOnCountableLoop(Loop *L); 152 153 AliasAnalysis *AA; 154 const DataLayout *DL; 155 DominatorTree *DT; 156 LoopInfo *LF; 157 const TargetLibraryInfo *TLI; 158 ScalarEvolution *SE; 159 bool HasMemcpy, HasMemmove; 160 }; 161 162 struct Simplifier { 163 struct Rule { 164 using FuncType = std::function<Value* (Instruction*, LLVMContext&)>; 165 Rule(StringRef N, FuncType F) : Name(N), Fn(F) {} 166 StringRef Name; // For 
debugging. 167 FuncType Fn; 168 }; 169 170 void addRule(StringRef N, const Rule::FuncType &F) { 171 Rules.push_back(Rule(N, F)); 172 } 173 174 private: 175 struct WorkListType { 176 WorkListType() = default; 177 178 void push_back(Value* V) { 179 // Do not push back duplicates. 180 if (!S.count(V)) { Q.push_back(V); S.insert(V); } 181 } 182 183 Value *pop_front_val() { 184 Value *V = Q.front(); Q.pop_front(); S.erase(V); 185 return V; 186 } 187 188 bool empty() const { return Q.empty(); } 189 190 private: 191 std::deque<Value*> Q; 192 std::set<Value*> S; 193 }; 194 195 using ValueSetType = std::set<Value *>; 196 197 std::vector<Rule> Rules; 198 199 public: 200 struct Context { 201 using ValueMapType = DenseMap<Value *, Value *>; 202 203 Value *Root; 204 ValueSetType Used; // The set of all cloned values used by Root. 205 ValueSetType Clones; // The set of all cloned values. 206 LLVMContext &Ctx; 207 208 Context(Instruction *Exp) 209 : Ctx(Exp->getParent()->getParent()->getContext()) { 210 initialize(Exp); 211 } 212 213 ~Context() { cleanup(); } 214 215 void print(raw_ostream &OS, const Value *V) const; 216 Value *materialize(BasicBlock *B, BasicBlock::iterator At); 217 218 private: 219 friend struct Simplifier; 220 221 void initialize(Instruction *Exp); 222 void cleanup(); 223 224 template <typename FuncT> void traverse(Value *V, FuncT F); 225 void record(Value *V); 226 void use(Value *V); 227 void unuse(Value *V); 228 229 bool equal(const Instruction *I, const Instruction *J) const; 230 Value *find(Value *Tree, Value *Sub) const; 231 Value *subst(Value *Tree, Value *OldV, Value *NewV); 232 void replace(Value *OldV, Value *NewV); 233 void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At); 234 }; 235 236 Value *simplify(Context &C); 237 }; 238 239 struct PE { 240 PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {} 241 242 const Simplifier::Context &C; 243 const Value *V; 244 }; 245 246 LLVM_ATTRIBUTE_USED 247 raw_ostream 
&operator<<(raw_ostream &OS, const PE &P) { 248 P.C.print(OS, P.V ? P.V : P.C.Root); 249 return OS; 250 } 251 252 } // end anonymous namespace 253 254 char HexagonLoopIdiomRecognize::ID = 0; 255 256 INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognize, "hexagon-loop-idiom", 257 "Recognize Hexagon-specific loop idioms", false, false) 258 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 259 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 260 INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) 261 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 262 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 263 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 264 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 265 INITIALIZE_PASS_END(HexagonLoopIdiomRecognize, "hexagon-loop-idiom", 266 "Recognize Hexagon-specific loop idioms", false, false) 267 268 template <typename FuncT> 269 void Simplifier::Context::traverse(Value *V, FuncT F) { 270 WorkListType Q; 271 Q.push_back(V); 272 273 while (!Q.empty()) { 274 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 275 if (!U || U->getParent()) 276 continue; 277 if (!F(U)) 278 continue; 279 for (Value *Op : U->operands()) 280 Q.push_back(Op); 281 } 282 } 283 284 void Simplifier::Context::print(raw_ostream &OS, const Value *V) const { 285 const auto *U = dyn_cast<const Instruction>(V); 286 if (!U) { 287 OS << V << '(' << *V << ')'; 288 return; 289 } 290 291 if (U->getParent()) { 292 OS << U << '('; 293 U->printAsOperand(OS, true); 294 OS << ')'; 295 return; 296 } 297 298 unsigned N = U->getNumOperands(); 299 if (N != 0) 300 OS << U << '('; 301 OS << U->getOpcodeName(); 302 for (const Value *Op : U->operands()) { 303 OS << ' '; 304 print(OS, Op); 305 } 306 if (N != 0) 307 OS << ')'; 308 } 309 310 void Simplifier::Context::initialize(Instruction *Exp) { 311 // Perform a deep clone of the expression, set Root to the root 312 // of the clone, and build a map from the cloned values to the 313 // original ones. 
314 ValueMapType M; 315 BasicBlock *Block = Exp->getParent(); 316 WorkListType Q; 317 Q.push_back(Exp); 318 319 while (!Q.empty()) { 320 Value *V = Q.pop_front_val(); 321 if (M.find(V) != M.end()) 322 continue; 323 if (Instruction *U = dyn_cast<Instruction>(V)) { 324 if (isa<PHINode>(U) || U->getParent() != Block) 325 continue; 326 for (Value *Op : U->operands()) 327 Q.push_back(Op); 328 M.insert({U, U->clone()}); 329 } 330 } 331 332 for (std::pair<Value*,Value*> P : M) { 333 Instruction *U = cast<Instruction>(P.second); 334 for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) { 335 auto F = M.find(U->getOperand(i)); 336 if (F != M.end()) 337 U->setOperand(i, F->second); 338 } 339 } 340 341 auto R = M.find(Exp); 342 assert(R != M.end()); 343 Root = R->second; 344 345 record(Root); 346 use(Root); 347 } 348 349 void Simplifier::Context::record(Value *V) { 350 auto Record = [this](Instruction *U) -> bool { 351 Clones.insert(U); 352 return true; 353 }; 354 traverse(V, Record); 355 } 356 357 void Simplifier::Context::use(Value *V) { 358 auto Use = [this](Instruction *U) -> bool { 359 Used.insert(U); 360 return true; 361 }; 362 traverse(V, Use); 363 } 364 365 void Simplifier::Context::unuse(Value *V) { 366 if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr) 367 return; 368 369 auto Unuse = [this](Instruction *U) -> bool { 370 if (!U->use_empty()) 371 return false; 372 Used.erase(U); 373 return true; 374 }; 375 traverse(V, Unuse); 376 } 377 378 Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) { 379 if (Tree == OldV) 380 return NewV; 381 if (OldV == NewV) 382 return Tree; 383 384 WorkListType Q; 385 Q.push_back(Tree); 386 while (!Q.empty()) { 387 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 388 // If U is not an instruction, or it's not a clone, skip it. 
389 if (!U || U->getParent()) 390 continue; 391 for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) { 392 Value *Op = U->getOperand(i); 393 if (Op == OldV) { 394 U->setOperand(i, NewV); 395 unuse(OldV); 396 } else { 397 Q.push_back(Op); 398 } 399 } 400 } 401 return Tree; 402 } 403 404 void Simplifier::Context::replace(Value *OldV, Value *NewV) { 405 if (Root == OldV) { 406 Root = NewV; 407 use(Root); 408 return; 409 } 410 411 // NewV may be a complex tree that has just been created by one of the 412 // transformation rules. We need to make sure that it is commoned with 413 // the existing Root to the maximum extent possible. 414 // Identify all subtrees of NewV (including NewV itself) that have 415 // equivalent counterparts in Root, and replace those subtrees with 416 // these counterparts. 417 WorkListType Q; 418 Q.push_back(NewV); 419 while (!Q.empty()) { 420 Value *V = Q.pop_front_val(); 421 Instruction *U = dyn_cast<Instruction>(V); 422 if (!U || U->getParent()) 423 continue; 424 if (Value *DupV = find(Root, V)) { 425 if (DupV != V) 426 NewV = subst(NewV, V, DupV); 427 } else { 428 for (Value *Op : U->operands()) 429 Q.push_back(Op); 430 } 431 } 432 433 // Now, simply replace OldV with NewV in Root. 
434 Root = subst(Root, OldV, NewV); 435 use(Root); 436 } 437 438 void Simplifier::Context::cleanup() { 439 for (Value *V : Clones) { 440 Instruction *U = cast<Instruction>(V); 441 if (!U->getParent()) 442 U->dropAllReferences(); 443 } 444 445 for (Value *V : Clones) { 446 Instruction *U = cast<Instruction>(V); 447 if (!U->getParent()) 448 U->deleteValue(); 449 } 450 } 451 452 bool Simplifier::Context::equal(const Instruction *I, 453 const Instruction *J) const { 454 if (I == J) 455 return true; 456 if (!I->isSameOperationAs(J)) 457 return false; 458 if (isa<PHINode>(I)) 459 return I->isIdenticalTo(J); 460 461 for (unsigned i = 0, n = I->getNumOperands(); i != n; ++i) { 462 Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i); 463 if (OpI == OpJ) 464 continue; 465 auto *InI = dyn_cast<const Instruction>(OpI); 466 auto *InJ = dyn_cast<const Instruction>(OpJ); 467 if (InI && InJ) { 468 if (!equal(InI, InJ)) 469 return false; 470 } else if (InI != InJ || !InI) 471 return false; 472 } 473 return true; 474 } 475 476 Value *Simplifier::Context::find(Value *Tree, Value *Sub) const { 477 Instruction *SubI = dyn_cast<Instruction>(Sub); 478 WorkListType Q; 479 Q.push_back(Tree); 480 481 while (!Q.empty()) { 482 Value *V = Q.pop_front_val(); 483 if (V == Sub) 484 return V; 485 Instruction *U = dyn_cast<Instruction>(V); 486 if (!U || U->getParent()) 487 continue; 488 if (SubI && equal(SubI, U)) 489 return U; 490 assert(!isa<PHINode>(U)); 491 for (Value *Op : U->operands()) 492 Q.push_back(Op); 493 } 494 return nullptr; 495 } 496 497 void Simplifier::Context::link(Instruction *I, BasicBlock *B, 498 BasicBlock::iterator At) { 499 if (I->getParent()) 500 return; 501 502 for (Value *Op : I->operands()) { 503 if (Instruction *OpI = dyn_cast<Instruction>(Op)) 504 link(OpI, B, At); 505 } 506 507 B->getInstList().insert(At, I); 508 } 509 510 Value *Simplifier::Context::materialize(BasicBlock *B, 511 BasicBlock::iterator At) { 512 if (Instruction *RootI = dyn_cast<Instruction>(Root)) 
513 link(RootI, B, At); 514 return Root; 515 } 516 517 Value *Simplifier::simplify(Context &C) { 518 WorkListType Q; 519 Q.push_back(C.Root); 520 unsigned Count = 0; 521 const unsigned Limit = SimplifyLimit; 522 523 while (!Q.empty()) { 524 if (Count++ >= Limit) 525 break; 526 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 527 if (!U || U->getParent() || !C.Used.count(U)) 528 continue; 529 bool Changed = false; 530 for (Rule &R : Rules) { 531 Value *W = R.Fn(U, C.Ctx); 532 if (!W) 533 continue; 534 Changed = true; 535 C.record(W); 536 C.replace(U, W); 537 Q.push_back(C.Root); 538 break; 539 } 540 if (!Changed) { 541 for (Value *Op : U->operands()) 542 Q.push_back(Op); 543 } 544 } 545 return Count < Limit ? C.Root : nullptr; 546 } 547 548 //===----------------------------------------------------------------------===// 549 // 550 // Implementation of PolynomialMultiplyRecognize 551 // 552 //===----------------------------------------------------------------------===// 553 554 namespace { 555 556 class PolynomialMultiplyRecognize { 557 public: 558 explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl, 559 const DominatorTree &dt, const TargetLibraryInfo &tli, 560 ScalarEvolution &se) 561 : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {} 562 563 bool recognize(); 564 565 private: 566 using ValueSeq = SetVector<Value *>; 567 568 IntegerType *getPmpyType() const { 569 LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext(); 570 return IntegerType::get(Ctx, 32); 571 } 572 573 bool isPromotableTo(Value *V, IntegerType *Ty); 574 void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB); 575 bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB); 576 577 Value *getCountIV(BasicBlock *BB); 578 bool findCycle(Value *Out, Value *In, ValueSeq &Cycle); 579 void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early, 580 ValueSeq &Late); 581 bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late); 
582 bool commutesWithShift(Instruction *I); 583 bool highBitsAreZero(Value *V, unsigned IterCount); 584 bool keepsHighBitsZero(Value *V, unsigned IterCount); 585 bool isOperandShifted(Instruction *I, Value *Op); 586 bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB, 587 unsigned IterCount); 588 void cleanupLoopBody(BasicBlock *LoopB); 589 590 struct ParsedValues { 591 ParsedValues() = default; 592 593 Value *M = nullptr; 594 Value *P = nullptr; 595 Value *Q = nullptr; 596 Value *R = nullptr; 597 Value *X = nullptr; 598 Instruction *Res = nullptr; 599 unsigned IterCount = 0; 600 bool Left = false; 601 bool Inv = false; 602 }; 603 604 bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV); 605 bool matchRightShift(SelectInst *SelI, ParsedValues &PV); 606 bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB, 607 Value *CIV, ParsedValues &PV, bool PreScan); 608 unsigned getInverseMxN(unsigned QP); 609 Value *generate(BasicBlock::iterator At, ParsedValues &PV); 610 611 void setupPreSimplifier(Simplifier &S); 612 void setupPostSimplifier(Simplifier &S); 613 614 Loop *CurLoop; 615 const DataLayout &DL; 616 const DominatorTree &DT; 617 const TargetLibraryInfo &TLI; 618 ScalarEvolution &SE; 619 }; 620 621 } // end anonymous namespace 622 623 Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) { 624 pred_iterator PI = pred_begin(BB), PE = pred_end(BB); 625 if (std::distance(PI, PE) != 2) 626 return nullptr; 627 BasicBlock *PB = (*PI == BB) ? 
*std::next(PI) : *PI; 628 629 for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) { 630 auto *PN = cast<PHINode>(I); 631 Value *InitV = PN->getIncomingValueForBlock(PB); 632 if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero()) 633 continue; 634 Value *IterV = PN->getIncomingValueForBlock(BB); 635 if (!isa<BinaryOperator>(IterV)) 636 continue; 637 auto *BO = dyn_cast<BinaryOperator>(IterV); 638 if (BO->getOpcode() != Instruction::Add) 639 continue; 640 Value *IncV = nullptr; 641 if (BO->getOperand(0) == PN) 642 IncV = BO->getOperand(1); 643 else if (BO->getOperand(1) == PN) 644 IncV = BO->getOperand(0); 645 if (IncV == nullptr) 646 continue; 647 648 if (auto *T = dyn_cast<ConstantInt>(IncV)) 649 if (T->getZExtValue() == 1) 650 return PN; 651 } 652 return nullptr; 653 } 654 655 static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) { 656 for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) { 657 Use &TheUse = UI.getUse(); 658 ++UI; 659 if (auto *II = dyn_cast<Instruction>(TheUse.getUser())) 660 if (BB == II->getParent()) 661 II->replaceUsesOfWith(I, J); 662 } 663 } 664 665 bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI, 666 Value *CIV, ParsedValues &PV) { 667 // Match the following: 668 // select (X & (1 << i)) != 0 ? R ^ (Q << i) : R 669 // select (X & (1 << i)) == 0 ? R : R ^ (Q << i) 670 // The condition may also check for equality with the masked value, i.e 671 // select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R 672 // select (X & (1 << i)) != (1 << i) ? 
R : R ^ (Q << i); 673 674 Value *CondV = SelI->getCondition(); 675 Value *TrueV = SelI->getTrueValue(); 676 Value *FalseV = SelI->getFalseValue(); 677 678 using namespace PatternMatch; 679 680 CmpInst::Predicate P; 681 Value *A = nullptr, *B = nullptr, *C = nullptr; 682 683 if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) && 684 !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B))))) 685 return false; 686 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 687 return false; 688 // Matched: select (A & B) == C ? ... : ... 689 // select (A & B) != C ? ... : ... 690 691 Value *X = nullptr, *Sh1 = nullptr; 692 // Check (A & B) for (X & (1 << i)): 693 if (match(A, m_Shl(m_One(), m_Specific(CIV)))) { 694 Sh1 = A; 695 X = B; 696 } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) { 697 Sh1 = B; 698 X = A; 699 } else { 700 // TODO: Could also check for an induction variable containing single 701 // bit shifted left by 1 in each iteration. 702 return false; 703 } 704 705 bool TrueIfZero; 706 707 // Check C against the possible values for comparison: 0 and (1 << i): 708 if (match(C, m_Zero())) 709 TrueIfZero = (P == CmpInst::ICMP_EQ); 710 else if (C == Sh1) 711 TrueIfZero = (P == CmpInst::ICMP_NE); 712 else 713 return false; 714 715 // So far, matched: 716 // select (X & (1 << i)) ? ... : ... 717 // including variations of the check against zero/non-zero value. 718 719 Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr; 720 if (TrueIfZero) { 721 ShouldSameV = TrueV; 722 ShouldXoredV = FalseV; 723 } else { 724 ShouldSameV = FalseV; 725 ShouldXoredV = TrueV; 726 } 727 728 Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr; 729 Value *T = nullptr; 730 if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) { 731 // Matched: select +++ ? ... : Y ^ Z 732 // select +++ ? Y ^ Z : ... 733 // where +++ denotes previously checked matches. 
734 if (ShouldSameV == Y) 735 T = Z; 736 else if (ShouldSameV == Z) 737 T = Y; 738 else 739 return false; 740 R = ShouldSameV; 741 // Matched: select +++ ? R : R ^ T 742 // select +++ ? R ^ T : R 743 // depending on TrueIfZero. 744 745 } else if (match(ShouldSameV, m_Zero())) { 746 // Matched: select +++ ? 0 : ... 747 // select +++ ? ... : 0 748 if (!SelI->hasOneUse()) 749 return false; 750 T = ShouldXoredV; 751 // Matched: select +++ ? 0 : T 752 // select +++ ? T : 0 753 754 Value *U = *SelI->user_begin(); 755 if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) && 756 !match(U, m_Xor(m_Value(R), m_Specific(SelI)))) 757 return false; 758 // Matched: xor (select +++ ? 0 : T), R 759 // xor (select +++ ? T : 0), R 760 } else 761 return false; 762 763 // The xor input value T is isolated into its own match so that it could 764 // be checked against an induction variable containing a shifted bit 765 // (todo). 766 // For now, check against (Q << i). 767 if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) && 768 !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV))))) 769 return false; 770 // Matched: select +++ ? R : R ^ (Q << i) 771 // select +++ ? R ^ (Q << i) : R 772 773 PV.X = X; 774 PV.Q = Q; 775 PV.R = R; 776 PV.Left = true; 777 return true; 778 } 779 780 bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI, 781 ParsedValues &PV) { 782 // Match the following: 783 // select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1) 784 // select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q 785 // The condition may also check for equality with the masked value, i.e 786 // select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1) 787 // select (X & 1) != 1 ? 
(R >> 1) : (R >> 1) ^ Q 788 789 Value *CondV = SelI->getCondition(); 790 Value *TrueV = SelI->getTrueValue(); 791 Value *FalseV = SelI->getFalseValue(); 792 793 using namespace PatternMatch; 794 795 Value *C = nullptr; 796 CmpInst::Predicate P; 797 bool TrueIfZero; 798 799 if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) || 800 match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))) { 801 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 802 return false; 803 // Matched: select C == 0 ? ... : ... 804 // select C != 0 ? ... : ... 805 TrueIfZero = (P == CmpInst::ICMP_EQ); 806 } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) || 807 match(CondV, m_ICmp(P, m_One(), m_Value(C)))) { 808 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 809 return false; 810 // Matched: select C == 1 ? ... : ... 811 // select C != 1 ? ... : ... 812 TrueIfZero = (P == CmpInst::ICMP_NE); 813 } else 814 return false; 815 816 Value *X = nullptr; 817 if (!match(C, m_And(m_Value(X), m_One())) && 818 !match(C, m_And(m_One(), m_Value(X)))) 819 return false; 820 // Matched: select (X & 1) == +++ ? ... : ... 821 // select (X & 1) != +++ ? ... : ... 822 823 Value *R = nullptr, *Q = nullptr; 824 if (TrueIfZero) { 825 // The select's condition is true if the tested bit is 0. 826 // TrueV must be the shift, FalseV must be the xor. 827 if (!match(TrueV, m_LShr(m_Value(R), m_One()))) 828 return false; 829 // Matched: select +++ ? (R >> 1) : ... 830 if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) && 831 !match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV)))) 832 return false; 833 // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q 834 // with commuting ^. 835 } else { 836 // The select's condition is true if the tested bit is 1. 837 // TrueV must be the xor, FalseV must be the shift. 838 if (!match(FalseV, m_LShr(m_Value(R), m_One()))) 839 return false; 840 // Matched: select +++ ? ... 
: (R >> 1) 841 if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) && 842 !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV)))) 843 return false; 844 // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1) 845 // with commuting ^. 846 } 847 848 PV.X = X; 849 PV.Q = Q; 850 PV.R = R; 851 PV.Left = false; 852 return true; 853 } 854 855 bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI, 856 BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV, 857 bool PreScan) { 858 using namespace PatternMatch; 859 860 // The basic pattern for R = P.Q is: 861 // for i = 0..31 862 // R = phi (0, R') 863 // if (P & (1 << i)) ; test-bit(P, i) 864 // R' = R ^ (Q << i) 865 // 866 // Similarly, the basic pattern for R = (P/Q).Q - P 867 // for i = 0..31 868 // R = phi(P, R') 869 // if (R & (1 << i)) 870 // R' = R ^ (Q << i) 871 872 // There exist idioms, where instead of Q being shifted left, P is shifted 873 // right. This produces a result that is shifted right by 32 bits (the 874 // non-shifted result is 64-bit). 875 // 876 // For R = P.Q, this would be: 877 // for i = 0..31 878 // R = phi (0, R') 879 // if ((P >> i) & 1) 880 // R' = (R >> 1) ^ Q ; R is cycled through the loop, so it must 881 // else ; be shifted by 1, not i. 882 // R' = R >> 1 883 // 884 // And for the inverse: 885 // for i = 0..31 886 // R = phi (P, R') 887 // if (R & 1) 888 // R' = (R >> 1) ^ Q 889 // else 890 // R' = R >> 1 891 892 // The left-shifting idioms share the same pattern: 893 // select (X & (1 << i)) ? R ^ (Q << i) : R 894 // Similarly for right-shifting idioms: 895 // select (X & 1) ? (R >> 1) ^ Q 896 897 if (matchLeftShift(SelI, CIV, PV)) { 898 // If this is a pre-scan, getting this far is sufficient. 899 if (PreScan) 900 return true; 901 902 // Need to make sure that the SelI goes back into R. 
903 auto *RPhi = dyn_cast<PHINode>(PV.R); 904 if (!RPhi) 905 return false; 906 if (SelI != RPhi->getIncomingValueForBlock(LoopB)) 907 return false; 908 PV.Res = SelI; 909 910 // If X is loop invariant, it must be the input polynomial, and the 911 // idiom is the basic polynomial multiply. 912 if (CurLoop->isLoopInvariant(PV.X)) { 913 PV.P = PV.X; 914 PV.Inv = false; 915 } else { 916 // X is not loop invariant. If X == R, this is the inverse pmpy. 917 // Otherwise, check for an xor with an invariant value. If the 918 // variable argument to the xor is R, then this is still a valid 919 // inverse pmpy. 920 PV.Inv = true; 921 if (PV.X != PV.R) { 922 Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr; 923 if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2)))) 924 return false; 925 auto *I1 = dyn_cast<Instruction>(X1); 926 auto *I2 = dyn_cast<Instruction>(X2); 927 if (!I1 || I1->getParent() != LoopB) { 928 Var = X2; 929 Inv = X1; 930 } else if (!I2 || I2->getParent() != LoopB) { 931 Var = X1; 932 Inv = X2; 933 } else 934 return false; 935 if (Var != PV.R) 936 return false; 937 PV.M = Inv; 938 } 939 // The input polynomial P still needs to be determined. It will be 940 // the entry value of R. 941 Value *EntryP = RPhi->getIncomingValueForBlock(PrehB); 942 PV.P = EntryP; 943 } 944 945 return true; 946 } 947 948 if (matchRightShift(SelI, PV)) { 949 // If this is an inverse pattern, the Q polynomial must be known at 950 // compile time. 951 if (PV.Inv && !isa<ConstantInt>(PV.Q)) 952 return false; 953 if (PreScan) 954 return true; 955 // There is no exact matching of right-shift pmpy. 
956 return false; 957 } 958 959 return false; 960 } 961 962 bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val, 963 IntegerType *DestTy) { 964 IntegerType *T = dyn_cast<IntegerType>(Val->getType()); 965 if (!T || T->getBitWidth() > DestTy->getBitWidth()) 966 return false; 967 if (T->getBitWidth() == DestTy->getBitWidth()) 968 return true; 969 // Non-instructions are promotable. The reason why an instruction may not 970 // be promotable is that it may produce a different result if its operands 971 // and the result are promoted, for example, it may produce more non-zero 972 // bits. While it would still be possible to represent the proper result 973 // in a wider type, it may require adding additional instructions (which 974 // we don't want to do). 975 Instruction *In = dyn_cast<Instruction>(Val); 976 if (!In) 977 return true; 978 // The bitwidth of the source type is smaller than the destination. 979 // Check if the individual operation can be promoted. 980 switch (In->getOpcode()) { 981 case Instruction::PHI: 982 case Instruction::ZExt: 983 case Instruction::And: 984 case Instruction::Or: 985 case Instruction::Xor: 986 case Instruction::LShr: // Shift right is ok. 987 case Instruction::Select: 988 case Instruction::Trunc: 989 return true; 990 case Instruction::ICmp: 991 if (CmpInst *CI = cast<CmpInst>(In)) 992 return CI->isEquality() || CI->isUnsigned(); 993 llvm_unreachable("Cast failed unexpectedly"); 994 case Instruction::Add: 995 return In->hasNoSignedWrap() && In->hasNoUnsignedWrap(); 996 } 997 return false; 998 } 999 1000 void PolynomialMultiplyRecognize::promoteTo(Instruction *In, 1001 IntegerType *DestTy, BasicBlock *LoopB) { 1002 Type *OrigTy = In->getType(); 1003 assert(!OrigTy->isVoidTy() && "Invalid instruction to promote"); 1004 1005 // Leave boolean values alone. 1006 if (!In->getType()->isIntegerTy(1)) 1007 In->mutateType(DestTy); 1008 unsigned DestBW = DestTy->getBitWidth(); 1009 1010 // Handle PHIs. 
1011 if (PHINode *P = dyn_cast<PHINode>(In)) { 1012 unsigned N = P->getNumIncomingValues(); 1013 for (unsigned i = 0; i != N; ++i) { 1014 BasicBlock *InB = P->getIncomingBlock(i); 1015 if (InB == LoopB) 1016 continue; 1017 Value *InV = P->getIncomingValue(i); 1018 IntegerType *Ty = cast<IntegerType>(InV->getType()); 1019 // Do not promote values in PHI nodes of type i1. 1020 if (Ty != P->getType()) { 1021 // If the value type does not match the PHI type, the PHI type 1022 // must have been promoted. 1023 assert(Ty->getBitWidth() < DestBW); 1024 InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy); 1025 P->setIncomingValue(i, InV); 1026 } 1027 } 1028 } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) { 1029 Value *Op = Z->getOperand(0); 1030 if (Op->getType() == Z->getType()) 1031 Z->replaceAllUsesWith(Op); 1032 Z->eraseFromParent(); 1033 return; 1034 } 1035 if (TruncInst *T = dyn_cast<TruncInst>(In)) { 1036 IntegerType *TruncTy = cast<IntegerType>(OrigTy); 1037 Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1); 1038 Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask); 1039 T->replaceAllUsesWith(And); 1040 T->eraseFromParent(); 1041 return; 1042 } 1043 1044 // Promote immediates. 1045 for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) { 1046 if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i))) 1047 if (CI->getType()->getBitWidth() < DestBW) 1048 In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue())); 1049 } 1050 } 1051 1052 bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB, 1053 BasicBlock *ExitB) { 1054 assert(LoopB); 1055 // Skip loops where the exit block has more than one predecessor. The values 1056 // coming from the loop block will be promoted to another type, and so the 1057 // values coming into the exit block from other predecessors would also have 1058 // to be promoted. 
1059 if (!ExitB || (ExitB->getSinglePredecessor() != LoopB)) 1060 return false; 1061 IntegerType *DestTy = getPmpyType(); 1062 // Check if the exit values have types that are no wider than the type 1063 // that we want to promote to. 1064 unsigned DestBW = DestTy->getBitWidth(); 1065 for (PHINode &P : ExitB->phis()) { 1066 if (P.getNumIncomingValues() != 1) 1067 return false; 1068 assert(P.getIncomingBlock(0) == LoopB); 1069 IntegerType *T = dyn_cast<IntegerType>(P.getType()); 1070 if (!T || T->getBitWidth() > DestBW) 1071 return false; 1072 } 1073 1074 // Check all instructions in the loop. 1075 for (Instruction &In : *LoopB) 1076 if (!In.isTerminator() && !isPromotableTo(&In, DestTy)) 1077 return false; 1078 1079 // Perform the promotion. 1080 std::vector<Instruction*> LoopIns; 1081 std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns), 1082 [](Instruction &In) { return &In; }); 1083 for (Instruction *In : LoopIns) 1084 if (!In->isTerminator()) 1085 promoteTo(In, DestTy, LoopB); 1086 1087 // Fix up the PHI nodes in the exit block. 1088 Instruction *EndI = ExitB->getFirstNonPHI(); 1089 BasicBlock::iterator End = EndI ? EndI->getIterator() : ExitB->end(); 1090 for (auto I = ExitB->begin(); I != End; ++I) { 1091 PHINode *P = dyn_cast<PHINode>(I); 1092 if (!P) 1093 break; 1094 Type *Ty0 = P->getIncomingValue(0)->getType(); 1095 Type *PTy = P->getType(); 1096 if (PTy != Ty0) { 1097 assert(Ty0 == DestTy); 1098 // In order to create the trunc, P must have the promoted type. 1099 P->mutateType(Ty0); 1100 Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy); 1101 // In order for the RAUW to work, the types of P and T must match. 1102 P->mutateType(PTy); 1103 P->replaceAllUsesWith(T); 1104 // Final update of the P's type. 
1105 P->mutateType(Ty0); 1106 cast<Instruction>(T)->setOperand(0, P); 1107 } 1108 } 1109 1110 return true; 1111 } 1112 1113 bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In, 1114 ValueSeq &Cycle) { 1115 // Out = ..., In, ... 1116 if (Out == In) 1117 return true; 1118 1119 auto *BB = cast<Instruction>(Out)->getParent(); 1120 bool HadPhi = false; 1121 1122 for (auto U : Out->users()) { 1123 auto *I = dyn_cast<Instruction>(&*U); 1124 if (I == nullptr || I->getParent() != BB) 1125 continue; 1126 // Make sure that there are no multi-iteration cycles, e.g. 1127 // p1 = phi(p2) 1128 // p2 = phi(p1) 1129 // The cycle p1->p2->p1 would span two loop iterations. 1130 // Check that there is only one phi in the cycle. 1131 bool IsPhi = isa<PHINode>(I); 1132 if (IsPhi && HadPhi) 1133 return false; 1134 HadPhi |= IsPhi; 1135 if (Cycle.count(I)) 1136 return false; 1137 Cycle.insert(I); 1138 if (findCycle(I, In, Cycle)) 1139 break; 1140 Cycle.remove(I); 1141 } 1142 return !Cycle.empty(); 1143 } 1144 1145 void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI, 1146 ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) { 1147 // All the values in the cycle that are between the phi node and the 1148 // divider instruction will be classified as "early", all other values 1149 // will be "late". 1150 1151 bool IsE = true; 1152 unsigned I, N = Cycle.size(); 1153 for (I = 0; I < N; ++I) { 1154 Value *V = Cycle[I]; 1155 if (DivI == V) 1156 IsE = false; 1157 else if (!isa<PHINode>(V)) 1158 continue; 1159 // Stop if found either. 1160 break; 1161 } 1162 // "I" is the index of either DivI or the phi node, whichever was first. 1163 // "E" is "false" or "true" respectively. 1164 ValueSeq &First = !IsE ? Early : Late; 1165 for (unsigned J = 0; J < I; ++J) 1166 First.insert(Cycle[J]); 1167 1168 ValueSeq &Second = IsE ? 
Early : Late; 1169 Second.insert(Cycle[I]); 1170 for (++I; I < N; ++I) { 1171 Value *V = Cycle[I]; 1172 if (DivI == V || isa<PHINode>(V)) 1173 break; 1174 Second.insert(V); 1175 } 1176 1177 for (; I < N; ++I) 1178 First.insert(Cycle[I]); 1179 } 1180 1181 bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI, 1182 ValueSeq &Early, ValueSeq &Late) { 1183 // Select is an exception, since the condition value does not have to be 1184 // classified in the same way as the true/false values. The true/false 1185 // values do have to be both early or both late. 1186 if (UseI->getOpcode() == Instruction::Select) { 1187 Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2); 1188 if (Early.count(TV) || Early.count(FV)) { 1189 if (Late.count(TV) || Late.count(FV)) 1190 return false; 1191 Early.insert(UseI); 1192 } else if (Late.count(TV) || Late.count(FV)) { 1193 if (Early.count(TV) || Early.count(FV)) 1194 return false; 1195 Late.insert(UseI); 1196 } 1197 return true; 1198 } 1199 1200 // Not sure what would be the example of this, but the code below relies 1201 // on having at least one operand. 1202 if (UseI->getNumOperands() == 0) 1203 return true; 1204 1205 bool AE = true, AL = true; 1206 for (auto &I : UseI->operands()) { 1207 if (Early.count(&*I)) 1208 AL = false; 1209 else if (Late.count(&*I)) 1210 AE = false; 1211 } 1212 // If the operands appear "all early" and "all late" at the same time, 1213 // then it means that none of them are actually classified as either. 1214 // This is harmless. 1215 if (AE && AL) 1216 return true; 1217 // Conversely, if they are neither "all early" nor "all late", then 1218 // we have a mixture of early and late operands that is not a known 1219 // exception. 1220 if (!AE && !AL) 1221 return false; 1222 1223 // Check that we have covered the two special cases. 
1224 assert(AE != AL); 1225 1226 if (AE) 1227 Early.insert(UseI); 1228 else 1229 Late.insert(UseI); 1230 return true; 1231 } 1232 1233 bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) { 1234 switch (I->getOpcode()) { 1235 case Instruction::And: 1236 case Instruction::Or: 1237 case Instruction::Xor: 1238 case Instruction::LShr: 1239 case Instruction::Shl: 1240 case Instruction::Select: 1241 case Instruction::ICmp: 1242 case Instruction::PHI: 1243 break; 1244 default: 1245 return false; 1246 } 1247 return true; 1248 } 1249 1250 bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V, 1251 unsigned IterCount) { 1252 auto *T = dyn_cast<IntegerType>(V->getType()); 1253 if (!T) 1254 return false; 1255 1256 KnownBits Known(T->getBitWidth()); 1257 computeKnownBits(V, Known, DL); 1258 return Known.countMinLeadingZeros() >= IterCount; 1259 } 1260 1261 bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V, 1262 unsigned IterCount) { 1263 // Assume that all inputs to the value have the high bits zero. 1264 // Check if the value itself preserves the zeros in the high bits. 
1265 if (auto *C = dyn_cast<ConstantInt>(V)) 1266 return C->getValue().countLeadingZeros() >= IterCount; 1267 1268 if (auto *I = dyn_cast<Instruction>(V)) { 1269 switch (I->getOpcode()) { 1270 case Instruction::And: 1271 case Instruction::Or: 1272 case Instruction::Xor: 1273 case Instruction::LShr: 1274 case Instruction::Select: 1275 case Instruction::ICmp: 1276 case Instruction::PHI: 1277 case Instruction::ZExt: 1278 return true; 1279 } 1280 } 1281 1282 return false; 1283 } 1284 1285 bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) { 1286 unsigned Opc = I->getOpcode(); 1287 if (Opc == Instruction::Shl || Opc == Instruction::LShr) 1288 return Op != I->getOperand(1); 1289 return true; 1290 } 1291 1292 bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB, 1293 BasicBlock *ExitB, unsigned IterCount) { 1294 Value *CIV = getCountIV(LoopB); 1295 if (CIV == nullptr) 1296 return false; 1297 auto *CIVTy = dyn_cast<IntegerType>(CIV->getType()); 1298 if (CIVTy == nullptr) 1299 return false; 1300 1301 ValueSeq RShifts; 1302 ValueSeq Early, Late, Cycled; 1303 1304 // Find all value cycles that contain logical right shifts by 1. 1305 for (Instruction &I : *LoopB) { 1306 using namespace PatternMatch; 1307 1308 Value *V = nullptr; 1309 if (!match(&I, m_LShr(m_Value(V), m_One()))) 1310 continue; 1311 ValueSeq C; 1312 if (!findCycle(&I, V, C)) 1313 continue; 1314 1315 // Found a cycle. 1316 C.insert(&I); 1317 classifyCycle(&I, C, Early, Late); 1318 Cycled.insert(C.begin(), C.end()); 1319 RShifts.insert(&I); 1320 } 1321 1322 // Find the set of all values affected by the shift cycles, i.e. all 1323 // cycled values, and (recursively) all their users. 
1324 ValueSeq Users(Cycled.begin(), Cycled.end()); 1325 for (unsigned i = 0; i < Users.size(); ++i) { 1326 Value *V = Users[i]; 1327 if (!isa<IntegerType>(V->getType())) 1328 return false; 1329 auto *R = cast<Instruction>(V); 1330 // If the instruction does not commute with shifts, the loop cannot 1331 // be unshifted. 1332 if (!commutesWithShift(R)) 1333 return false; 1334 for (auto I = R->user_begin(), E = R->user_end(); I != E; ++I) { 1335 auto *T = cast<Instruction>(*I); 1336 // Skip users from outside of the loop. They will be handled later. 1337 // Also, skip the right-shifts and phi nodes, since they mix early 1338 // and late values. 1339 if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T)) 1340 continue; 1341 1342 Users.insert(T); 1343 if (!classifyInst(T, Early, Late)) 1344 return false; 1345 } 1346 } 1347 1348 if (Users.empty()) 1349 return false; 1350 1351 // Verify that high bits remain zero. 1352 ValueSeq Internal(Users.begin(), Users.end()); 1353 ValueSeq Inputs; 1354 for (unsigned i = 0; i < Internal.size(); ++i) { 1355 auto *R = dyn_cast<Instruction>(Internal[i]); 1356 if (!R) 1357 continue; 1358 for (Value *Op : R->operands()) { 1359 auto *T = dyn_cast<Instruction>(Op); 1360 if (T && T->getParent() != LoopB) 1361 Inputs.insert(Op); 1362 else 1363 Internal.insert(Op); 1364 } 1365 } 1366 for (Value *V : Inputs) 1367 if (!highBitsAreZero(V, IterCount)) 1368 return false; 1369 for (Value *V : Internal) 1370 if (!keepsHighBitsZero(V, IterCount)) 1371 return false; 1372 1373 // Finally, the work can be done. Unshift each user. 
1374 IRBuilder<> IRB(LoopB); 1375 std::map<Value*,Value*> ShiftMap; 1376 1377 using CastMapType = std::map<std::pair<Value *, Type *>, Value *>; 1378 1379 CastMapType CastMap; 1380 1381 auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V, 1382 IntegerType *Ty) -> Value* { 1383 auto H = CM.find(std::make_pair(V, Ty)); 1384 if (H != CM.end()) 1385 return H->second; 1386 Value *CV = IRB.CreateIntCast(V, Ty, false); 1387 CM.insert(std::make_pair(std::make_pair(V, Ty), CV)); 1388 return CV; 1389 }; 1390 1391 for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) { 1392 using namespace PatternMatch; 1393 1394 if (isa<PHINode>(I) || !Users.count(&*I)) 1395 continue; 1396 1397 // Match lshr x, 1. 1398 Value *V = nullptr; 1399 if (match(&*I, m_LShr(m_Value(V), m_One()))) { 1400 replaceAllUsesOfWithIn(&*I, V, LoopB); 1401 continue; 1402 } 1403 // For each non-cycled operand, replace it with the corresponding 1404 // value shifted left. 1405 for (auto &J : I->operands()) { 1406 Value *Op = J.get(); 1407 if (!isOperandShifted(&*I, Op)) 1408 continue; 1409 if (Users.count(Op)) 1410 continue; 1411 // Skip shifting zeros. 1412 if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) 1413 continue; 1414 // Check if we have already generated a shift for this value. 1415 auto F = ShiftMap.find(Op); 1416 Value *W = (F != ShiftMap.end()) ? F->second : nullptr; 1417 if (W == nullptr) { 1418 IRB.SetInsertPoint(&*I); 1419 // First, the shift amount will be CIV or CIV+1, depending on 1420 // whether the value is early or late. Instead of creating CIV+1, 1421 // do a single shift of the value. 1422 Value *ShAmt = CIV, *ShVal = Op; 1423 auto *VTy = cast<IntegerType>(ShVal->getType()); 1424 auto *ATy = cast<IntegerType>(ShAmt->getType()); 1425 if (Late.count(&*I)) 1426 ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1)); 1427 // Second, the types of the shifted value and the shift amount 1428 // must match. 
1429 if (VTy != ATy) { 1430 if (VTy->getBitWidth() < ATy->getBitWidth()) 1431 ShVal = upcast(CastMap, IRB, ShVal, ATy); 1432 else 1433 ShAmt = upcast(CastMap, IRB, ShAmt, VTy); 1434 } 1435 // Ready to generate the shift and memoize it. 1436 W = IRB.CreateShl(ShVal, ShAmt); 1437 ShiftMap.insert(std::make_pair(Op, W)); 1438 } 1439 I->replaceUsesOfWith(Op, W); 1440 } 1441 } 1442 1443 // Update the users outside of the loop to account for having left 1444 // shifts. They would normally be shifted right in the loop, so shift 1445 // them right after the loop exit. 1446 // Take advantage of the loop-closed SSA form, which has all the post- 1447 // loop values in phi nodes. 1448 IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt()); 1449 for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) { 1450 if (!isa<PHINode>(P)) 1451 break; 1452 auto *PN = cast<PHINode>(P); 1453 Value *U = PN->getIncomingValueForBlock(LoopB); 1454 if (!Users.count(U)) 1455 continue; 1456 Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount)); 1457 PN->replaceAllUsesWith(S); 1458 // The above RAUW will create 1459 // S = lshr S, IterCount 1460 // so we need to fix it back into 1461 // S = lshr PN, IterCount 1462 cast<User>(S)->replaceUsesOfWith(S, PN); 1463 } 1464 1465 return true; 1466 } 1467 1468 void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) { 1469 for (auto &I : *LoopB) 1470 if (Value *SV = SimplifyInstruction(&I, {DL, &TLI, &DT})) 1471 I.replaceAllUsesWith(SV); 1472 1473 for (auto I = LoopB->begin(), N = I; I != LoopB->end(); I = N) { 1474 N = std::next(I); 1475 RecursivelyDeleteTriviallyDeadInstructions(&*I, &TLI); 1476 } 1477 } 1478 1479 unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) { 1480 // Arrays of coefficients of Q and the inverse, C. 1481 // Q[i] = coefficient at x^i. 
1482 std::array<char,32> Q, C; 1483 1484 for (unsigned i = 0; i < 32; ++i) { 1485 Q[i] = QP & 1; 1486 QP >>= 1; 1487 } 1488 assert(Q[0] == 1); 1489 1490 // Find C, such that 1491 // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1 1492 // 1493 // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the 1494 // operations * and + are & and ^ respectively. 1495 // 1496 // Find C[i] recursively, by comparing i-th coefficient in the product 1497 // with 0 (or 1 for i=0). 1498 // 1499 // C[0] = 1, since C[0] = Q[0], and Q[0] = 1. 1500 C[0] = 1; 1501 for (unsigned i = 1; i < 32; ++i) { 1502 // Solve for C[i] in: 1503 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0 1504 // This is equivalent to 1505 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0 1506 // which is 1507 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i] 1508 unsigned T = 0; 1509 for (unsigned j = 0; j < i; ++j) 1510 T = T ^ (C[j] & Q[i-j]); 1511 C[i] = T; 1512 } 1513 1514 unsigned QV = 0; 1515 for (unsigned i = 0; i < 32; ++i) 1516 if (C[i]) 1517 QV |= (1 << i); 1518 1519 return QV; 1520 } 1521 1522 Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At, 1523 ParsedValues &PV) { 1524 IRBuilder<> B(&*At); 1525 Module *M = At->getParent()->getParent()->getParent(); 1526 Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw); 1527 1528 Value *P = PV.P, *Q = PV.Q, *P0 = P; 1529 unsigned IC = PV.IterCount; 1530 1531 if (PV.M != nullptr) 1532 P0 = P = B.CreateXor(P, PV.M); 1533 1534 // Create a bit mask to clear the high bits beyond IterCount. 1535 auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC)); 1536 1537 if (PV.IterCount != 32) 1538 P = B.CreateAnd(P, BMI); 1539 1540 if (PV.Inv) { 1541 auto *QI = dyn_cast<ConstantInt>(PV.Q); 1542 assert(QI && QI->getBitWidth() <= 32); 1543 1544 // Again, clearing bits beyond IterCount. 
1545 unsigned M = (1 << PV.IterCount) - 1; 1546 unsigned Tmp = (QI->getZExtValue() | 1) & M; 1547 unsigned QV = getInverseMxN(Tmp) & M; 1548 auto *QVI = ConstantInt::get(QI->getType(), QV); 1549 P = B.CreateCall(PMF, {P, QVI}); 1550 P = B.CreateTrunc(P, QI->getType()); 1551 if (IC != 32) 1552 P = B.CreateAnd(P, BMI); 1553 } 1554 1555 Value *R = B.CreateCall(PMF, {P, Q}); 1556 1557 if (PV.M != nullptr) 1558 R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false)); 1559 1560 return R; 1561 } 1562 1563 static bool hasZeroSignBit(const Value *V) { 1564 if (const auto *CI = dyn_cast<const ConstantInt>(V)) 1565 return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0; 1566 const Instruction *I = dyn_cast<const Instruction>(V); 1567 if (!I) 1568 return false; 1569 switch (I->getOpcode()) { 1570 case Instruction::LShr: 1571 if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1))) 1572 return SI->getZExtValue() > 0; 1573 return false; 1574 case Instruction::Or: 1575 case Instruction::Xor: 1576 return hasZeroSignBit(I->getOperand(0)) && 1577 hasZeroSignBit(I->getOperand(1)); 1578 case Instruction::And: 1579 return hasZeroSignBit(I->getOperand(0)) || 1580 hasZeroSignBit(I->getOperand(1)); 1581 } 1582 return false; 1583 } 1584 1585 void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) { 1586 S.addRule("sink-zext", 1587 // Sink zext past bitwise operations. 
1588 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1589 if (I->getOpcode() != Instruction::ZExt) 1590 return nullptr; 1591 Instruction *T = dyn_cast<Instruction>(I->getOperand(0)); 1592 if (!T) 1593 return nullptr; 1594 switch (T->getOpcode()) { 1595 case Instruction::And: 1596 case Instruction::Or: 1597 case Instruction::Xor: 1598 break; 1599 default: 1600 return nullptr; 1601 } 1602 IRBuilder<> B(Ctx); 1603 return B.CreateBinOp(cast<BinaryOperator>(T)->getOpcode(), 1604 B.CreateZExt(T->getOperand(0), I->getType()), 1605 B.CreateZExt(T->getOperand(1), I->getType())); 1606 }); 1607 S.addRule("xor/and -> and/xor", 1608 // (xor (and x a) (and y a)) -> (and (xor x y) a) 1609 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1610 if (I->getOpcode() != Instruction::Xor) 1611 return nullptr; 1612 Instruction *And0 = dyn_cast<Instruction>(I->getOperand(0)); 1613 Instruction *And1 = dyn_cast<Instruction>(I->getOperand(1)); 1614 if (!And0 || !And1) 1615 return nullptr; 1616 if (And0->getOpcode() != Instruction::And || 1617 And1->getOpcode() != Instruction::And) 1618 return nullptr; 1619 if (And0->getOperand(1) != And1->getOperand(1)) 1620 return nullptr; 1621 IRBuilder<> B(Ctx); 1622 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1->getOperand(0)), 1623 And0->getOperand(1)); 1624 }); 1625 S.addRule("sink binop into select", 1626 // (Op (select c x y) z) -> (select c (Op x z) (Op y z)) 1627 // (Op x (select c y z)) -> (select c (Op x y) (Op x z)) 1628 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1629 BinaryOperator *BO = dyn_cast<BinaryOperator>(I); 1630 if (!BO) 1631 return nullptr; 1632 Instruction::BinaryOps Op = BO->getOpcode(); 1633 if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(0))) { 1634 IRBuilder<> B(Ctx); 1635 Value *X = Sel->getTrueValue(), *Y = Sel->getFalseValue(); 1636 Value *Z = BO->getOperand(1); 1637 return B.CreateSelect(Sel->getCondition(), 1638 B.CreateBinOp(Op, X, Z), 1639 B.CreateBinOp(Op, Y, Z)); 1640 } 1641 if (SelectInst *Sel 
= dyn_cast<SelectInst>(BO->getOperand(1))) { 1642 IRBuilder<> B(Ctx); 1643 Value *X = BO->getOperand(0); 1644 Value *Y = Sel->getTrueValue(), *Z = Sel->getFalseValue(); 1645 return B.CreateSelect(Sel->getCondition(), 1646 B.CreateBinOp(Op, X, Y), 1647 B.CreateBinOp(Op, X, Z)); 1648 } 1649 return nullptr; 1650 }); 1651 S.addRule("fold select-select", 1652 // (select c (select c x y) z) -> (select c x z) 1653 // (select c x (select c y z)) -> (select c x z) 1654 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1655 SelectInst *Sel = dyn_cast<SelectInst>(I); 1656 if (!Sel) 1657 return nullptr; 1658 IRBuilder<> B(Ctx); 1659 Value *C = Sel->getCondition(); 1660 if (SelectInst *Sel0 = dyn_cast<SelectInst>(Sel->getTrueValue())) { 1661 if (Sel0->getCondition() == C) 1662 return B.CreateSelect(C, Sel0->getTrueValue(), Sel->getFalseValue()); 1663 } 1664 if (SelectInst *Sel1 = dyn_cast<SelectInst>(Sel->getFalseValue())) { 1665 if (Sel1->getCondition() == C) 1666 return B.CreateSelect(C, Sel->getTrueValue(), Sel1->getFalseValue()); 1667 } 1668 return nullptr; 1669 }); 1670 S.addRule("or-signbit -> xor-signbit", 1671 // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0) 1672 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1673 if (I->getOpcode() != Instruction::Or) 1674 return nullptr; 1675 ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1)); 1676 if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit()) 1677 return nullptr; 1678 if (!hasZeroSignBit(I->getOperand(0))) 1679 return nullptr; 1680 return IRBuilder<>(Ctx).CreateXor(I->getOperand(0), Msb); 1681 }); 1682 S.addRule("sink lshr into binop", 1683 // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c)) 1684 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1685 if (I->getOpcode() != Instruction::LShr) 1686 return nullptr; 1687 BinaryOperator *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0)); 1688 if (!BitOp) 1689 return nullptr; 1690 switch (BitOp->getOpcode()) { 1691 case Instruction::And: 1692 case 
Instruction::Or: 1693 case Instruction::Xor: 1694 break; 1695 default: 1696 return nullptr; 1697 } 1698 IRBuilder<> B(Ctx); 1699 Value *S = I->getOperand(1); 1700 return B.CreateBinOp(BitOp->getOpcode(), 1701 B.CreateLShr(BitOp->getOperand(0), S), 1702 B.CreateLShr(BitOp->getOperand(1), S)); 1703 }); 1704 S.addRule("expose bitop-const", 1705 // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b)) 1706 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1707 auto IsBitOp = [](unsigned Op) -> bool { 1708 switch (Op) { 1709 case Instruction::And: 1710 case Instruction::Or: 1711 case Instruction::Xor: 1712 return true; 1713 } 1714 return false; 1715 }; 1716 BinaryOperator *BitOp1 = dyn_cast<BinaryOperator>(I); 1717 if (!BitOp1 || !IsBitOp(BitOp1->getOpcode())) 1718 return nullptr; 1719 BinaryOperator *BitOp2 = dyn_cast<BinaryOperator>(BitOp1->getOperand(0)); 1720 if (!BitOp2 || !IsBitOp(BitOp2->getOpcode())) 1721 return nullptr; 1722 ConstantInt *CA = dyn_cast<ConstantInt>(BitOp2->getOperand(1)); 1723 ConstantInt *CB = dyn_cast<ConstantInt>(BitOp1->getOperand(1)); 1724 if (!CA || !CB) 1725 return nullptr; 1726 IRBuilder<> B(Ctx); 1727 Value *X = BitOp2->getOperand(0); 1728 return B.CreateBinOp(BitOp2->getOpcode(), X, 1729 B.CreateBinOp(BitOp1->getOpcode(), CA, CB)); 1730 }); 1731 } 1732 1733 void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) { 1734 S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a", 1735 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1736 if (I->getOpcode() != Instruction::And) 1737 return nullptr; 1738 Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0)); 1739 ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1)); 1740 if (!Xor || !C0) 1741 return nullptr; 1742 if (Xor->getOpcode() != Instruction::Xor) 1743 return nullptr; 1744 Instruction *And0 = dyn_cast<Instruction>(Xor->getOperand(0)); 1745 Instruction *And1 = dyn_cast<Instruction>(Xor->getOperand(1)); 1746 // Pick the first non-null and. 
1747 if (!And0 || And0->getOpcode() != Instruction::And) 1748 std::swap(And0, And1); 1749 ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1)); 1750 if (!C1) 1751 return nullptr; 1752 uint32_t V0 = C0->getZExtValue(); 1753 uint32_t V1 = C1->getZExtValue(); 1754 if (V0 != (V0 & V1)) 1755 return nullptr; 1756 IRBuilder<> B(Ctx); 1757 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1), C0); 1758 }); 1759 } 1760 1761 bool PolynomialMultiplyRecognize::recognize() { 1762 LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n" 1763 << *CurLoop << '\n'); 1764 // Restrictions: 1765 // - The loop must consist of a single block. 1766 // - The iteration count must be known at compile-time. 1767 // - The loop must have an induction variable starting from 0, and 1768 // incremented in each iteration of the loop. 1769 BasicBlock *LoopB = CurLoop->getHeader(); 1770 LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB); 1771 1772 if (LoopB != CurLoop->getLoopLatch()) 1773 return false; 1774 BasicBlock *ExitB = CurLoop->getExitBlock(); 1775 if (ExitB == nullptr) 1776 return false; 1777 BasicBlock *EntryB = CurLoop->getLoopPreheader(); 1778 if (EntryB == nullptr) 1779 return false; 1780 1781 unsigned IterCount = 0; 1782 const SCEV *CT = SE.getBackedgeTakenCount(CurLoop); 1783 if (isa<SCEVCouldNotCompute>(CT)) 1784 return false; 1785 if (auto *CV = dyn_cast<SCEVConstant>(CT)) 1786 IterCount = CV->getValue()->getZExtValue() + 1; 1787 1788 Value *CIV = getCountIV(LoopB); 1789 ParsedValues PV; 1790 Simplifier PreSimp; 1791 PV.IterCount = IterCount; 1792 LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount 1793 << '\n'); 1794 1795 setupPreSimplifier(PreSimp); 1796 1797 // Perform a preliminary scan of select instructions to see if any of them 1798 // looks like a generator of the polynomial multiply steps. 
Assume that a 1799 // loop can only contain a single transformable operation, so stop the 1800 // traversal after the first reasonable candidate was found. 1801 // XXX: Currently this approach can modify the loop before being 100% sure 1802 // that the transformation can be carried out. 1803 bool FoundPreScan = false; 1804 auto FeedsPHI = [LoopB](const Value *V) -> bool { 1805 for (const Value *U : V->users()) { 1806 if (const auto *P = dyn_cast<const PHINode>(U)) 1807 if (P->getParent() == LoopB) 1808 return true; 1809 } 1810 return false; 1811 }; 1812 for (Instruction &In : *LoopB) { 1813 SelectInst *SI = dyn_cast<SelectInst>(&In); 1814 if (!SI || !FeedsPHI(SI)) 1815 continue; 1816 1817 Simplifier::Context C(SI); 1818 Value *T = PreSimp.simplify(C); 1819 SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : SI; 1820 LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n'); 1821 if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) { 1822 FoundPreScan = true; 1823 if (SelI != SI) { 1824 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1825 SI->replaceAllUsesWith(NewSel); 1826 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1827 } 1828 break; 1829 } 1830 } 1831 1832 if (!FoundPreScan) { 1833 LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n"); 1834 return false; 1835 } 1836 1837 if (!PV.Left) { 1838 // The right shift version actually only returns the higher bits of 1839 // the result (each iteration discards the LSB). If we want to convert it 1840 // to a left-shifting loop, the working data type must be at least as 1841 // wide as the target's pmpy instruction. 1842 if (!promoteTypes(LoopB, ExitB)) 1843 return false; 1844 // Run post-promotion simplifications. 
1845 Simplifier PostSimp; 1846 setupPostSimplifier(PostSimp); 1847 for (Instruction &In : *LoopB) { 1848 SelectInst *SI = dyn_cast<SelectInst>(&In); 1849 if (!SI || !FeedsPHI(SI)) 1850 continue; 1851 Simplifier::Context C(SI); 1852 Value *T = PostSimp.simplify(C); 1853 SelectInst *SelI = dyn_cast_or_null<SelectInst>(T); 1854 if (SelI != SI) { 1855 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1856 SI->replaceAllUsesWith(NewSel); 1857 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1858 } 1859 break; 1860 } 1861 1862 if (!convertShiftsToLeft(LoopB, ExitB, IterCount)) 1863 return false; 1864 cleanupLoopBody(LoopB); 1865 } 1866 1867 // Scan the loop again, find the generating select instruction. 1868 bool FoundScan = false; 1869 for (Instruction &In : *LoopB) { 1870 SelectInst *SelI = dyn_cast<SelectInst>(&In); 1871 if (!SelI) 1872 continue; 1873 LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n'); 1874 FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false); 1875 if (FoundScan) 1876 break; 1877 } 1878 assert(FoundScan); 1879 1880 LLVM_DEBUG({ 1881 StringRef PP = (PV.M ? 
"(P+M)" : "P"); 1882 if (!PV.Inv) 1883 dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n"; 1884 else 1885 dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + " 1886 << PP << "\n"; 1887 dbgs() << " Res:" << *PV.Res << "\n P:" << *PV.P << "\n"; 1888 if (PV.M) 1889 dbgs() << " M:" << *PV.M << "\n"; 1890 dbgs() << " Q:" << *PV.Q << "\n"; 1891 dbgs() << " Iteration count:" << PV.IterCount << "\n"; 1892 }); 1893 1894 BasicBlock::iterator At(EntryB->getTerminator()); 1895 Value *PM = generate(At, PV); 1896 if (PM == nullptr) 1897 return false; 1898 1899 if (PM->getType() != PV.Res->getType()) 1900 PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false); 1901 1902 PV.Res->replaceAllUsesWith(PM); 1903 PV.Res->eraseFromParent(); 1904 return true; 1905 } 1906 1907 int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) { 1908 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1))) 1909 return SC->getAPInt().getSExtValue(); 1910 return 0; 1911 } 1912 1913 bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) { 1914 // Allow volatile stores if HexagonVolatileMemcpy is enabled. 1915 if (!(SI->isVolatile() && HexagonVolatileMemcpy) && !SI->isSimple()) 1916 return false; 1917 1918 Value *StoredVal = SI->getValueOperand(); 1919 Value *StorePtr = SI->getPointerOperand(); 1920 1921 // Reject stores that are so large that they overflow an unsigned. 1922 uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); 1923 if ((SizeInBits & 7) || (SizeInBits >> 32) != 0) 1924 return false; 1925 1926 // See if the pointer expression is an AddRec like {base,+,1} on the current 1927 // loop, which indicates a strided store. If we have something else, it's a 1928 // random store we can't handle. 
1929 auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 1930 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine()) 1931 return false; 1932 1933 // Check to see if the stride matches the size of the store. If so, then we 1934 // know that every byte is touched in the loop. 1935 int Stride = getSCEVStride(StoreEv); 1936 if (Stride == 0) 1937 return false; 1938 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 1939 if (StoreSize != unsigned(std::abs(Stride))) 1940 return false; 1941 1942 // The store must be feeding a non-volatile load. 1943 LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); 1944 if (!LI || !LI->isSimple()) 1945 return false; 1946 1947 // See if the pointer expression is an AddRec like {base,+,1} on the current 1948 // loop, which indicates a strided load. If we have something else, it's a 1949 // random load we can't handle. 1950 Value *LoadPtr = LI->getPointerOperand(); 1951 auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr)); 1952 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine()) 1953 return false; 1954 1955 // The store and load must share the same stride. 1956 if (StoreEv->getOperand(1) != LoadEv->getOperand(1)) 1957 return false; 1958 1959 // Success. This store can be converted into a memcpy. 1960 return true; 1961 } 1962 1963 /// mayLoopAccessLocation - Return true if the specified loop might access the 1964 /// specified pointer location, which is a loop-strided access. The 'Access' 1965 /// argument specifies what the verboten forms of access are (read or write). 1966 static bool 1967 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, 1968 const SCEV *BECount, unsigned StoreSize, 1969 AliasAnalysis &AA, 1970 SmallPtrSetImpl<Instruction *> &Ignored) { 1971 // Get the location that may be stored across the loop. 
Since the access 1972 // is strided positively through memory, we say that the modified location 1973 // starts at the pointer and has infinite size. 1974 LocationSize AccessSize = LocationSize::unknown(); 1975 1976 // If the loop iterates a fixed number of times, we can refine the access 1977 // size to be exactly the size of the memset, which is (BECount+1)*StoreSize 1978 if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) 1979 AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * 1980 StoreSize); 1981 1982 // TODO: For this to be really effective, we have to dive into the pointer 1983 // operand in the store. Store to &A[i] of 100 will always return may alias 1984 // with store of &A[100], we need to StoreLoc to be "A" with size of 100, 1985 // which will then no-alias a store to &A[100]. 1986 MemoryLocation StoreLoc(Ptr, AccessSize); 1987 1988 for (auto *B : L->blocks()) 1989 for (auto &I : *B) 1990 if (Ignored.count(&I) == 0 && 1991 isModOrRefSet( 1992 intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) 1993 return true; 1994 1995 return false; 1996 } 1997 1998 void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB, 1999 SmallVectorImpl<StoreInst*> &Stores) { 2000 Stores.clear(); 2001 for (Instruction &I : *BB) 2002 if (StoreInst *SI = dyn_cast<StoreInst>(&I)) 2003 if (isLegalStore(CurLoop, SI)) 2004 Stores.push_back(SI); 2005 } 2006 2007 bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop, 2008 StoreInst *SI, const SCEV *BECount) { 2009 assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) && 2010 "Expected only non-volatile stores, or Hexagon-specific memcpy" 2011 "to volatile destination."); 2012 2013 Value *StorePtr = SI->getPointerOperand(); 2014 auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 2015 unsigned Stride = getSCEVStride(StoreEv); 2016 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 2017 if (Stride != StoreSize) 
2018 return false; 2019 2020 // See if the pointer expression is an AddRec like {base,+,1} on the current 2021 // loop, which indicates a strided load. If we have something else, it's a 2022 // random load we can't handle. 2023 LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); 2024 auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); 2025 2026 // The trip count of the loop and the base pointer of the addrec SCEV is 2027 // guaranteed to be loop invariant, which means that it should dominate the 2028 // header. This allows us to insert code for it in the preheader. 2029 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 2030 Instruction *ExpPt = Preheader->getTerminator(); 2031 IRBuilder<> Builder(ExpPt); 2032 SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom"); 2033 2034 Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace()); 2035 2036 // Okay, we have a strided store "p[i]" of a loaded value. We can turn 2037 // this into a memcpy/memmove in the loop preheader now if we want. However, 2038 // this would be unsafe to do if there is anything else in the loop that may 2039 // read or write the memory region we're storing to. For memcpy, this 2040 // includes the load that feeds the stores. Check for an alias by generating 2041 // the base address and checking everything. 2042 Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(), 2043 Builder.getInt8PtrTy(SI->getPointerAddressSpace()), ExpPt); 2044 Value *LoadBasePtr = nullptr; 2045 2046 bool Overlap = false; 2047 bool DestVolatile = SI->isVolatile(); 2048 Type *BECountTy = BECount->getType(); 2049 2050 if (DestVolatile) { 2051 // The trip count must fit in i32, since it is the type of the "num_words" 2052 // argument to hexagon_memcpy_forward_vp4cp4n2. 2053 if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) { 2054 CleanupAndExit: 2055 // If we generated new code for the base pointer, clean up. 
2056 Expander.clear(); 2057 if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) { 2058 RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); 2059 StoreBasePtr = nullptr; 2060 } 2061 if (LoadBasePtr) { 2062 RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); 2063 LoadBasePtr = nullptr; 2064 } 2065 return false; 2066 } 2067 } 2068 2069 SmallPtrSet<Instruction*, 2> Ignore1; 2070 Ignore1.insert(SI); 2071 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, 2072 StoreSize, *AA, Ignore1)) { 2073 // Check if the load is the offending instruction. 2074 Ignore1.insert(LI); 2075 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, 2076 BECount, StoreSize, *AA, Ignore1)) { 2077 // Still bad. Nothing we can do. 2078 goto CleanupAndExit; 2079 } 2080 // It worked with the load ignored. 2081 Overlap = true; 2082 } 2083 2084 if (!Overlap) { 2085 if (DisableMemcpyIdiom || !HasMemcpy) 2086 goto CleanupAndExit; 2087 } else { 2088 // Don't generate memmove if this function will be inlined. This is 2089 // because the caller will undergo this transformation after inlining. 2090 Function *Func = CurLoop->getHeader()->getParent(); 2091 if (Func->hasFnAttribute(Attribute::AlwaysInline)) 2092 goto CleanupAndExit; 2093 2094 // In case of a memmove, the call to memmove will be executed instead 2095 // of the loop, so we need to make sure that there is nothing else in 2096 // the loop than the load, store and instructions that these two depend 2097 // on. 2098 SmallVector<Instruction*,2> Insts; 2099 Insts.push_back(SI); 2100 Insts.push_back(LI); 2101 if (!coverLoop(CurLoop, Insts)) 2102 goto CleanupAndExit; 2103 2104 if (DisableMemmoveIdiom || !HasMemmove) 2105 goto CleanupAndExit; 2106 bool IsNested = CurLoop->getParentLoop() != nullptr; 2107 if (IsNested && OnlyNonNestedMemmove) 2108 goto CleanupAndExit; 2109 } 2110 2111 // For a memcpy, we have to make sure that the input array is not being 2112 // mutated by the loop. 
2113 LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(), 2114 Builder.getInt8PtrTy(LI->getPointerAddressSpace()), ExpPt); 2115 2116 SmallPtrSet<Instruction*, 2> Ignore2; 2117 Ignore2.insert(SI); 2118 if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, 2119 StoreSize, *AA, Ignore2)) 2120 goto CleanupAndExit; 2121 2122 // Check the stride. 2123 bool StridePos = getSCEVStride(LoadEv) >= 0; 2124 2125 // Currently, the volatile memcpy only emulates traversing memory forward. 2126 if (!StridePos && DestVolatile) 2127 goto CleanupAndExit; 2128 2129 bool RuntimeCheck = (Overlap || DestVolatile); 2130 2131 BasicBlock *ExitB; 2132 if (RuntimeCheck) { 2133 // The runtime check needs a single exit block. 2134 SmallVector<BasicBlock*, 8> ExitBlocks; 2135 CurLoop->getUniqueExitBlocks(ExitBlocks); 2136 if (ExitBlocks.size() != 1) 2137 goto CleanupAndExit; 2138 ExitB = ExitBlocks[0]; 2139 } 2140 2141 // The # stored bytes is (BECount+1)*Size. Expand the trip count out to 2142 // pointer size if it isn't already. 
2143 LLVMContext &Ctx = SI->getContext(); 2144 BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); 2145 DebugLoc DLoc = SI->getDebugLoc(); 2146 2147 const SCEV *NumBytesS = 2148 SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); 2149 if (StoreSize != 1) 2150 NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), 2151 SCEV::FlagNUW); 2152 Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt); 2153 if (Instruction *In = dyn_cast<Instruction>(NumBytes)) 2154 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2155 NumBytes = Simp; 2156 2157 CallInst *NewCall; 2158 2159 if (RuntimeCheck) { 2160 unsigned Threshold = RuntimeMemSizeThreshold; 2161 if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) { 2162 uint64_t C = CI->getZExtValue(); 2163 if (Threshold != 0 && C < Threshold) 2164 goto CleanupAndExit; 2165 if (C < CompileTimeMemSizeThreshold) 2166 goto CleanupAndExit; 2167 } 2168 2169 BasicBlock *Header = CurLoop->getHeader(); 2170 Function *Func = Header->getParent(); 2171 Loop *ParentL = LF->getLoopFor(Preheader); 2172 StringRef HeaderName = Header->getName(); 2173 2174 // Create a new (empty) preheader, and update the PHI nodes in the 2175 // header to use the new preheader. 2176 BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph", 2177 Func, Header); 2178 if (ParentL) 2179 ParentL->addBasicBlockToLoop(NewPreheader, *LF); 2180 IRBuilder<>(NewPreheader).CreateBr(Header); 2181 for (auto &In : *Header) { 2182 PHINode *PN = dyn_cast<PHINode>(&In); 2183 if (!PN) 2184 break; 2185 int bx = PN->getBasicBlockIndex(Preheader); 2186 if (bx >= 0) 2187 PN->setIncomingBlock(bx, NewPreheader); 2188 } 2189 DT->addNewBlock(NewPreheader, Preheader); 2190 DT->changeImmediateDominator(Header, NewPreheader); 2191 2192 // Check for safe conditions to execute memmove. 2193 // If stride is positive, copying things from higher to lower addresses 2194 // is equivalent to memmove. 
For negative stride, it's the other way 2195 // around. Copying forward in memory with positive stride may not be 2196 // same as memmove since we may be copying values that we just stored 2197 // in some previous iteration. 2198 Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy); 2199 Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy); 2200 Value *LowA = StridePos ? SA : LA; 2201 Value *HighA = StridePos ? LA : SA; 2202 Value *CmpA = Builder.CreateICmpULT(LowA, HighA); 2203 Value *Cond = CmpA; 2204 2205 // Check for distance between pointers. Since the case LowA < HighA 2206 // is checked for above, assume LowA >= HighA. 2207 Value *Dist = Builder.CreateSub(LowA, HighA); 2208 Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist); 2209 Value *CmpEither = Builder.CreateOr(Cond, CmpD); 2210 Cond = CmpEither; 2211 2212 if (Threshold != 0) { 2213 Type *Ty = NumBytes->getType(); 2214 Value *Thr = ConstantInt::get(Ty, Threshold); 2215 Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes); 2216 Value *CmpBoth = Builder.CreateAnd(Cond, CmpB); 2217 Cond = CmpBoth; 2218 } 2219 BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli", 2220 Func, NewPreheader); 2221 if (ParentL) 2222 ParentL->addBasicBlockToLoop(MemmoveB, *LF); 2223 Instruction *OldT = Preheader->getTerminator(); 2224 Builder.CreateCondBr(Cond, MemmoveB, NewPreheader); 2225 OldT->eraseFromParent(); 2226 Preheader->setName(Preheader->getName()+".old"); 2227 DT->addNewBlock(MemmoveB, Preheader); 2228 // Find the new immediate dominator of the exit block. 2229 BasicBlock *ExitD = Preheader; 2230 for (auto PI = pred_begin(ExitB), PE = pred_end(ExitB); PI != PE; ++PI) { 2231 BasicBlock *PB = *PI; 2232 ExitD = DT->findNearestCommonDominator(ExitD, PB); 2233 if (!ExitD) 2234 break; 2235 } 2236 // If the prior immediate dominator of ExitB was dominated by the 2237 // old preheader, then the old preheader becomes the new immediate 2238 // dominator. 
Otherwise don't change anything (because the newly 2239 // added blocks are dominated by the old preheader). 2240 if (ExitD && DT->dominates(Preheader, ExitD)) { 2241 DomTreeNode *BN = DT->getNode(ExitB); 2242 DomTreeNode *DN = DT->getNode(ExitD); 2243 BN->setIDom(DN); 2244 } 2245 2246 // Add a call to memmove to the conditional block. 2247 IRBuilder<> CondBuilder(MemmoveB); 2248 CondBuilder.CreateBr(ExitB); 2249 CondBuilder.SetInsertPoint(MemmoveB->getTerminator()); 2250 2251 if (DestVolatile) { 2252 Type *Int32Ty = Type::getInt32Ty(Ctx); 2253 Type *Int32PtrTy = Type::getInt32PtrTy(Ctx); 2254 Type *VoidTy = Type::getVoidTy(Ctx); 2255 Module *M = Func->getParent(); 2256 FunctionCallee Fn = M->getOrInsertFunction( 2257 HexagonVolatileMemcpyName, VoidTy, Int32PtrTy, Int32PtrTy, Int32Ty); 2258 2259 const SCEV *OneS = SE->getConstant(Int32Ty, 1); 2260 const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty); 2261 const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW); 2262 Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty, 2263 MemmoveB->getTerminator()); 2264 if (Instruction *In = dyn_cast<Instruction>(NumWords)) 2265 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2266 NumWords = Simp; 2267 2268 Value *Op0 = (StoreBasePtr->getType() == Int32PtrTy) 2269 ? StoreBasePtr 2270 : CondBuilder.CreateBitCast(StoreBasePtr, Int32PtrTy); 2271 Value *Op1 = (LoadBasePtr->getType() == Int32PtrTy) 2272 ? LoadBasePtr 2273 : CondBuilder.CreateBitCast(LoadBasePtr, Int32PtrTy); 2274 NewCall = CondBuilder.CreateCall(Fn, {Op0, Op1, NumWords}); 2275 } else { 2276 NewCall = CondBuilder.CreateMemMove(StoreBasePtr, SI->getAlignment(), 2277 LoadBasePtr, LI->getAlignment(), 2278 NumBytes); 2279 } 2280 } else { 2281 NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlignment(), 2282 LoadBasePtr, LI->getAlignment(), 2283 NumBytes); 2284 // Okay, the memcpy has been formed. Zap the original store and 2285 // anything that feeds into it. 
2286 RecursivelyDeleteTriviallyDeadInstructions(SI, TLI); 2287 } 2288 2289 NewCall->setDebugLoc(DLoc); 2290 2291 LLVM_DEBUG(dbgs() << " Formed " << (Overlap ? "memmove: " : "memcpy: ") 2292 << *NewCall << "\n" 2293 << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" 2294 << " from store ptr=" << *StoreEv << " at: " << *SI 2295 << "\n"); 2296 2297 return true; 2298 } 2299 2300 // Check if the instructions in Insts, together with their dependencies 2301 // cover the loop in the sense that the loop could be safely eliminated once 2302 // the instructions in Insts are removed. 2303 bool HexagonLoopIdiomRecognize::coverLoop(Loop *L, 2304 SmallVectorImpl<Instruction*> &Insts) const { 2305 SmallSet<BasicBlock*,8> LoopBlocks; 2306 for (auto *B : L->blocks()) 2307 LoopBlocks.insert(B); 2308 2309 SetVector<Instruction*> Worklist(Insts.begin(), Insts.end()); 2310 2311 // Collect all instructions from the loop that the instructions in Insts 2312 // depend on (plus their dependencies, etc.). These instructions will 2313 // constitute the expression trees that feed those in Insts, but the trees 2314 // will be limited only to instructions contained in the loop. 2315 for (unsigned i = 0; i < Worklist.size(); ++i) { 2316 Instruction *In = Worklist[i]; 2317 for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) { 2318 Instruction *OpI = dyn_cast<Instruction>(I); 2319 if (!OpI) 2320 continue; 2321 BasicBlock *PB = OpI->getParent(); 2322 if (!LoopBlocks.count(PB)) 2323 continue; 2324 Worklist.insert(OpI); 2325 } 2326 } 2327 2328 // Scan all instructions in the loop, if any of them have a user outside 2329 // of the loop, or outside of the expressions collected above, then either 2330 // the loop has a side-effect visible outside of it, or there are 2331 // instructions in it that are not involved in the original set Insts. 
2332 for (auto *B : L->blocks()) { 2333 for (auto &In : *B) { 2334 if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In)) 2335 continue; 2336 if (!Worklist.count(&In) && In.mayHaveSideEffects()) 2337 return false; 2338 for (const auto &K : In.users()) { 2339 Instruction *UseI = dyn_cast<Instruction>(K); 2340 if (!UseI) 2341 continue; 2342 BasicBlock *UseB = UseI->getParent(); 2343 if (LF->getLoopFor(UseB) != L) 2344 return false; 2345 } 2346 } 2347 } 2348 2349 return true; 2350 } 2351 2352 /// runOnLoopBlock - Process the specified block, which lives in a counted loop 2353 /// with the specified backedge count. This block is known to be in the current 2354 /// loop and not in any subloops. 2355 bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, 2356 const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) { 2357 // We can only promote stores in this block if they are unconditionally 2358 // executed in the loop. For a block to be unconditionally executed, it has 2359 // to dominate all the exit blocks of the loop. Verify this now. 2360 auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool { 2361 return DT->dominates(BB, EB); 2362 }; 2363 if (!all_of(ExitBlocks, DominatedByBB)) 2364 return false; 2365 2366 bool MadeChange = false; 2367 // Look for store instructions, which may be optimized to memset/memcpy. 2368 SmallVector<StoreInst*,8> Stores; 2369 collectStores(CurLoop, BB, Stores); 2370 2371 // Optimize the store into a memcpy, if it feeds an similarly strided load. 
2372 for (auto &SI : Stores) 2373 MadeChange |= processCopyingStore(CurLoop, SI, BECount); 2374 2375 return MadeChange; 2376 } 2377 2378 bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) { 2379 PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE); 2380 if (PMR.recognize()) 2381 return true; 2382 2383 if (!HasMemcpy && !HasMemmove) 2384 return false; 2385 2386 const SCEV *BECount = SE->getBackedgeTakenCount(L); 2387 assert(!isa<SCEVCouldNotCompute>(BECount) && 2388 "runOnCountableLoop() called on a loop without a predictable" 2389 "backedge-taken count"); 2390 2391 SmallVector<BasicBlock *, 8> ExitBlocks; 2392 L->getUniqueExitBlocks(ExitBlocks); 2393 2394 bool Changed = false; 2395 2396 // Scan all the blocks in the loop that are not in subloops. 2397 for (auto *BB : L->getBlocks()) { 2398 // Ignore blocks in subloops. 2399 if (LF->getLoopFor(BB) != L) 2400 continue; 2401 Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks); 2402 } 2403 2404 return Changed; 2405 } 2406 2407 bool HexagonLoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) { 2408 const Module &M = *L->getHeader()->getParent()->getParent(); 2409 if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon) 2410 return false; 2411 2412 if (skipLoop(L)) 2413 return false; 2414 2415 // If the loop could not be converted to canonical form, it must have an 2416 // indirectbr in it, just give up. 2417 if (!L->getLoopPreheader()) 2418 return false; 2419 2420 // Disable loop idiom recognition if the function's name is a common idiom. 
2421 StringRef Name = L->getHeader()->getParent()->getName(); 2422 if (Name == "memset" || Name == "memcpy" || Name == "memmove") 2423 return false; 2424 2425 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 2426 DL = &L->getHeader()->getModule()->getDataLayout(); 2427 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 2428 LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2429 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 2430 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 2431 2432 HasMemcpy = TLI->has(LibFunc_memcpy); 2433 HasMemmove = TLI->has(LibFunc_memmove); 2434 2435 if (SE->hasLoopInvariantBackedgeTakenCount(L)) 2436 return runOnCountableLoop(L); 2437 return false; 2438 } 2439 2440 Pass *llvm::createHexagonLoopIdiomPass() { 2441 return new HexagonLoopIdiomRecognize(); 2442 } 2443