1 //===- HexagonLoopIdiomRecognition.cpp ------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/APInt.h" 10 #include "llvm/ADT/DenseMap.h" 11 #include "llvm/ADT/SetVector.h" 12 #include "llvm/ADT/SmallPtrSet.h" 13 #include "llvm/ADT/SmallSet.h" 14 #include "llvm/ADT/SmallVector.h" 15 #include "llvm/ADT/StringRef.h" 16 #include "llvm/ADT/Triple.h" 17 #include "llvm/Analysis/AliasAnalysis.h" 18 #include "llvm/Analysis/InstructionSimplify.h" 19 #include "llvm/Analysis/LoopInfo.h" 20 #include "llvm/Analysis/LoopPass.h" 21 #include "llvm/Analysis/MemoryLocation.h" 22 #include "llvm/Analysis/ScalarEvolution.h" 23 #include "llvm/Analysis/ScalarEvolutionExpander.h" 24 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 25 #include "llvm/Analysis/TargetLibraryInfo.h" 26 #include "llvm/Analysis/ValueTracking.h" 27 #include "llvm/IR/Attributes.h" 28 #include "llvm/IR/BasicBlock.h" 29 #include "llvm/IR/Constant.h" 30 #include "llvm/IR/Constants.h" 31 #include "llvm/IR/DataLayout.h" 32 #include "llvm/IR/DebugLoc.h" 33 #include "llvm/IR/DerivedTypes.h" 34 #include "llvm/IR/Dominators.h" 35 #include "llvm/IR/Function.h" 36 #include "llvm/IR/IRBuilder.h" 37 #include "llvm/IR/InstrTypes.h" 38 #include "llvm/IR/Instruction.h" 39 #include "llvm/IR/Instructions.h" 40 #include "llvm/IR/IntrinsicInst.h" 41 #include "llvm/IR/Intrinsics.h" 42 #include "llvm/IR/IntrinsicsHexagon.h" 43 #include "llvm/IR/Module.h" 44 #include "llvm/IR/PatternMatch.h" 45 #include "llvm/IR/Type.h" 46 #include "llvm/IR/User.h" 47 #include "llvm/IR/Value.h" 48 #include "llvm/InitializePasses.h" 49 #include "llvm/Pass.h" 50 #include "llvm/Support/Casting.h" 51 #include "llvm/Support/CommandLine.h" 52 
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <functional>
#include <iterator>
#include <map>
#include <set>
#include <utility>
#include <vector>

#define DEBUG_TYPE "hexagon-lir"

using namespace llvm;

// Command-line knobs for this file. All of them are hidden options.

// Off by default, i.e. memcpy generation is allowed unless this is set.
static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memcpy in loop idiom recognition"));

// Off by default, i.e. memmove generation is allowed unless this is set.
static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memmove in loop idiom recognition"));

// Size threshold, in bytes, for the runtime check; defaults to 0.
static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
  "check guarding the memmove."));

// Size threshold, in bytes, used when the transfer size is a compile-time
// constant; defaults to 64.
static cl::opt<unsigned> CompileTimeMemSizeThreshold(
  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
  cl::desc("Threshold (in bytes) to perform the transformation, if the "
    "runtime loop count (mem transfer size) is known at compile-time."));

// On by default: memmove is only generated for non-nested loops.
static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
  cl::Hidden, cl::init(true),
  cl::desc("Only enable generating memmove in non-nested loops"));

// NOTE(review): the flag is spelled "disable-..." while its description says
// "Enable ..." and it defaults to false. The code that reads this option is
// not part of this section -- confirm the intended polarity before relying
// on either the name or the description.
static cl::opt<bool> HexagonVolatileMemcpy(
  "disable-hexagon-volatile-memcpy", cl::Hidden, cl::init(false),
  cl::desc("Enable Hexagon-specific memcpy for volatile destination."));

// Upper bound on the number of worklist steps taken by Simplifier::simplify.
static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
  cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR"));

105 static const char *HexagonVolatileMemcpyName 106 = "hexagon_memcpy_forward_vp4cp4n2"; 107 108 109 namespace llvm { 110 111 void initializeHexagonLoopIdiomRecognizePass(PassRegistry&); 112 Pass *createHexagonLoopIdiomPass(); 113 114 } // end namespace llvm 115 116 namespace { 117 118 class HexagonLoopIdiomRecognize : public LoopPass { 119 public: 120 static char ID; 121 122 explicit HexagonLoopIdiomRecognize() : LoopPass(ID) { 123 initializeHexagonLoopIdiomRecognizePass(*PassRegistry::getPassRegistry()); 124 } 125 126 StringRef getPassName() const override { 127 return "Recognize Hexagon-specific loop idioms"; 128 } 129 130 void getAnalysisUsage(AnalysisUsage &AU) const override { 131 AU.addRequired<LoopInfoWrapperPass>(); 132 AU.addRequiredID(LoopSimplifyID); 133 AU.addRequiredID(LCSSAID); 134 AU.addRequired<AAResultsWrapperPass>(); 135 AU.addPreserved<AAResultsWrapperPass>(); 136 AU.addRequired<ScalarEvolutionWrapperPass>(); 137 AU.addRequired<DominatorTreeWrapperPass>(); 138 AU.addRequired<TargetLibraryInfoWrapperPass>(); 139 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 140 } 141 142 bool runOnLoop(Loop *L, LPPassManager &LPM) override; 143 144 private: 145 int getSCEVStride(const SCEVAddRecExpr *StoreEv); 146 bool isLegalStore(Loop *CurLoop, StoreInst *SI); 147 void collectStores(Loop *CurLoop, BasicBlock *BB, 148 SmallVectorImpl<StoreInst*> &Stores); 149 bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount); 150 bool coverLoop(Loop *L, SmallVectorImpl<Instruction*> &Insts) const; 151 bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount, 152 SmallVectorImpl<BasicBlock*> &ExitBlocks); 153 bool runOnCountableLoop(Loop *L); 154 155 AliasAnalysis *AA; 156 const DataLayout *DL; 157 DominatorTree *DT; 158 LoopInfo *LF; 159 const TargetLibraryInfo *TLI; 160 ScalarEvolution *SE; 161 bool HasMemcpy, HasMemmove; 162 }; 163 164 struct Simplifier { 165 struct Rule { 166 using FuncType = std::function<Value* (Instruction*, 
LLVMContext&)>; 167 Rule(StringRef N, FuncType F) : Name(N), Fn(F) {} 168 StringRef Name; // For debugging. 169 FuncType Fn; 170 }; 171 172 void addRule(StringRef N, const Rule::FuncType &F) { 173 Rules.push_back(Rule(N, F)); 174 } 175 176 private: 177 struct WorkListType { 178 WorkListType() = default; 179 180 void push_back(Value* V) { 181 // Do not push back duplicates. 182 if (!S.count(V)) { Q.push_back(V); S.insert(V); } 183 } 184 185 Value *pop_front_val() { 186 Value *V = Q.front(); Q.pop_front(); S.erase(V); 187 return V; 188 } 189 190 bool empty() const { return Q.empty(); } 191 192 private: 193 std::deque<Value*> Q; 194 std::set<Value*> S; 195 }; 196 197 using ValueSetType = std::set<Value *>; 198 199 std::vector<Rule> Rules; 200 201 public: 202 struct Context { 203 using ValueMapType = DenseMap<Value *, Value *>; 204 205 Value *Root; 206 ValueSetType Used; // The set of all cloned values used by Root. 207 ValueSetType Clones; // The set of all cloned values. 208 LLVMContext &Ctx; 209 210 Context(Instruction *Exp) 211 : Ctx(Exp->getParent()->getParent()->getContext()) { 212 initialize(Exp); 213 } 214 215 ~Context() { cleanup(); } 216 217 void print(raw_ostream &OS, const Value *V) const; 218 Value *materialize(BasicBlock *B, BasicBlock::iterator At); 219 220 private: 221 friend struct Simplifier; 222 223 void initialize(Instruction *Exp); 224 void cleanup(); 225 226 template <typename FuncT> void traverse(Value *V, FuncT F); 227 void record(Value *V); 228 void use(Value *V); 229 void unuse(Value *V); 230 231 bool equal(const Instruction *I, const Instruction *J) const; 232 Value *find(Value *Tree, Value *Sub) const; 233 Value *subst(Value *Tree, Value *OldV, Value *NewV); 234 void replace(Value *OldV, Value *NewV); 235 void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At); 236 }; 237 238 Value *simplify(Context &C); 239 }; 240 241 struct PE { 242 PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {} 243 244 const 
Simplifier::Context &C; 245 const Value *V; 246 }; 247 248 LLVM_ATTRIBUTE_USED 249 raw_ostream &operator<<(raw_ostream &OS, const PE &P) { 250 P.C.print(OS, P.V ? P.V : P.C.Root); 251 return OS; 252 } 253 254 } // end anonymous namespace 255 256 char HexagonLoopIdiomRecognize::ID = 0; 257 258 INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognize, "hexagon-loop-idiom", 259 "Recognize Hexagon-specific loop idioms", false, false) 260 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 261 INITIALIZE_PASS_DEPENDENCY(LoopSimplify) 262 INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass) 263 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) 264 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 265 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 266 INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) 267 INITIALIZE_PASS_END(HexagonLoopIdiomRecognize, "hexagon-loop-idiom", 268 "Recognize Hexagon-specific loop idioms", false, false) 269 270 template <typename FuncT> 271 void Simplifier::Context::traverse(Value *V, FuncT F) { 272 WorkListType Q; 273 Q.push_back(V); 274 275 while (!Q.empty()) { 276 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 277 if (!U || U->getParent()) 278 continue; 279 if (!F(U)) 280 continue; 281 for (Value *Op : U->operands()) 282 Q.push_back(Op); 283 } 284 } 285 286 void Simplifier::Context::print(raw_ostream &OS, const Value *V) const { 287 const auto *U = dyn_cast<const Instruction>(V); 288 if (!U) { 289 OS << V << '(' << *V << ')'; 290 return; 291 } 292 293 if (U->getParent()) { 294 OS << U << '('; 295 U->printAsOperand(OS, true); 296 OS << ')'; 297 return; 298 } 299 300 unsigned N = U->getNumOperands(); 301 if (N != 0) 302 OS << U << '('; 303 OS << U->getOpcodeName(); 304 for (const Value *Op : U->operands()) { 305 OS << ' '; 306 print(OS, Op); 307 } 308 if (N != 0) 309 OS << ')'; 310 } 311 312 void Simplifier::Context::initialize(Instruction *Exp) { 313 // Perform a deep clone of the expression, set Root to the root 314 // of the clone, 
and build a map from the cloned values to the 315 // original ones. 316 ValueMapType M; 317 BasicBlock *Block = Exp->getParent(); 318 WorkListType Q; 319 Q.push_back(Exp); 320 321 while (!Q.empty()) { 322 Value *V = Q.pop_front_val(); 323 if (M.find(V) != M.end()) 324 continue; 325 if (Instruction *U = dyn_cast<Instruction>(V)) { 326 if (isa<PHINode>(U) || U->getParent() != Block) 327 continue; 328 for (Value *Op : U->operands()) 329 Q.push_back(Op); 330 M.insert({U, U->clone()}); 331 } 332 } 333 334 for (std::pair<Value*,Value*> P : M) { 335 Instruction *U = cast<Instruction>(P.second); 336 for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) { 337 auto F = M.find(U->getOperand(i)); 338 if (F != M.end()) 339 U->setOperand(i, F->second); 340 } 341 } 342 343 auto R = M.find(Exp); 344 assert(R != M.end()); 345 Root = R->second; 346 347 record(Root); 348 use(Root); 349 } 350 351 void Simplifier::Context::record(Value *V) { 352 auto Record = [this](Instruction *U) -> bool { 353 Clones.insert(U); 354 return true; 355 }; 356 traverse(V, Record); 357 } 358 359 void Simplifier::Context::use(Value *V) { 360 auto Use = [this](Instruction *U) -> bool { 361 Used.insert(U); 362 return true; 363 }; 364 traverse(V, Use); 365 } 366 367 void Simplifier::Context::unuse(Value *V) { 368 if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr) 369 return; 370 371 auto Unuse = [this](Instruction *U) -> bool { 372 if (!U->use_empty()) 373 return false; 374 Used.erase(U); 375 return true; 376 }; 377 traverse(V, Unuse); 378 } 379 380 Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) { 381 if (Tree == OldV) 382 return NewV; 383 if (OldV == NewV) 384 return Tree; 385 386 WorkListType Q; 387 Q.push_back(Tree); 388 while (!Q.empty()) { 389 Instruction *U = dyn_cast<Instruction>(Q.pop_front_val()); 390 // If U is not an instruction, or it's not a clone, skip it. 
if (!U || U->getParent())
      continue;
    for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) {
      Value *Op = U->getOperand(i);
      if (Op == OldV) {
        U->setOperand(i, NewV);
        // The operand slot no longer refers to OldV; drop OldV (and whatever
        // became unreferenced below it) from the Used set.
        unuse(OldV);
      } else {
        Q.push_back(Op);
      }
    }
  }
  return Tree;
}

/// Replaces OldV with NewV in the tree rooted at Root and marks everything
/// reachable from the resulting root as used. When OldV is the root itself,
/// NewV simply becomes the new root.
void Simplifier::Context::replace(Value *OldV, Value *NewV) {
  if (Root == OldV) {
    Root = NewV;
    use(Root);
    return;
  }

  // NewV may be a complex tree that has just been created by one of the
  // transformation rules. We need to make sure that it is commoned with
  // the existing Root to the maximum extent possible.
  // Identify all subtrees of NewV (including NewV itself) that have
  // equivalent counterparts in Root, and replace those subtrees with
  // these counterparts.
  WorkListType Q;
  Q.push_back(NewV);
  while (!Q.empty()) {
    Value *V = Q.pop_front_val();
    Instruction *U = dyn_cast<Instruction>(V);
    if (!U || U->getParent())
      continue;
    if (Value *DupV = find(Root, V)) {
      // An equivalent subtree already exists in Root: reuse it, and do not
      // descend into V's operands.
      if (DupV != V)
        NewV = subst(NewV, V, DupV);
    } else {
      for (Value *Op : U->operands())
        Q.push_back(Op);
    }
  }

  // Now, simply replace OldV with NewV in Root.
436 Root = subst(Root, OldV, NewV); 437 use(Root); 438 } 439 440 void Simplifier::Context::cleanup() { 441 for (Value *V : Clones) { 442 Instruction *U = cast<Instruction>(V); 443 if (!U->getParent()) 444 U->dropAllReferences(); 445 } 446 447 for (Value *V : Clones) { 448 Instruction *U = cast<Instruction>(V); 449 if (!U->getParent()) 450 U->deleteValue(); 451 } 452 } 453 454 bool Simplifier::Context::equal(const Instruction *I, 455 const Instruction *J) const { 456 if (I == J) 457 return true; 458 if (!I->isSameOperationAs(J)) 459 return false; 460 if (isa<PHINode>(I)) 461 return I->isIdenticalTo(J); 462 463 for (unsigned i = 0, n = I->getNumOperands(); i != n; ++i) { 464 Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i); 465 if (OpI == OpJ) 466 continue; 467 auto *InI = dyn_cast<const Instruction>(OpI); 468 auto *InJ = dyn_cast<const Instruction>(OpJ); 469 if (InI && InJ) { 470 if (!equal(InI, InJ)) 471 return false; 472 } else if (InI != InJ || !InI) 473 return false; 474 } 475 return true; 476 } 477 478 Value *Simplifier::Context::find(Value *Tree, Value *Sub) const { 479 Instruction *SubI = dyn_cast<Instruction>(Sub); 480 WorkListType Q; 481 Q.push_back(Tree); 482 483 while (!Q.empty()) { 484 Value *V = Q.pop_front_val(); 485 if (V == Sub) 486 return V; 487 Instruction *U = dyn_cast<Instruction>(V); 488 if (!U || U->getParent()) 489 continue; 490 if (SubI && equal(SubI, U)) 491 return U; 492 assert(!isa<PHINode>(U)); 493 for (Value *Op : U->operands()) 494 Q.push_back(Op); 495 } 496 return nullptr; 497 } 498 499 void Simplifier::Context::link(Instruction *I, BasicBlock *B, 500 BasicBlock::iterator At) { 501 if (I->getParent()) 502 return; 503 504 for (Value *Op : I->operands()) { 505 if (Instruction *OpI = dyn_cast<Instruction>(Op)) 506 link(OpI, B, At); 507 } 508 509 B->getInstList().insert(At, I); 510 } 511 512 Value *Simplifier::Context::materialize(BasicBlock *B, 513 BasicBlock::iterator At) { 514 if (Instruction *RootI = dyn_cast<Instruction>(Root)) 
link(RootI, B, At);
  return Root;
}

/// Repeatedly applies the registered rules to the used, unlinked
/// instructions of C, starting at C.Root.
///
/// For each visited instruction the rules are tried in order; the first one
/// that returns a non-null value replaces the instruction and the walk
/// restarts from the (possibly new) root. When no rule applies, the operands
/// are queued instead.
///
/// Returns C.Root, or nullptr when the step counter reached SimplifyLimit.
Value *Simplifier::simplify(Context &C) {
  WorkListType Q;
  Q.push_back(C.Root);
  unsigned Count = 0;
  const unsigned Limit = SimplifyLimit;

  while (!Q.empty()) {
    if (Count++ >= Limit)
      break;
    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
    if (!U || U->getParent() || !C.Used.count(U))
      continue;
    bool Changed = false;
    for (Rule &R : Rules) {
      Value *W = R.Fn(U, C.Ctx);
      if (!W)
        continue;
      Changed = true;
      C.record(W);
      C.replace(U, W);
      Q.push_back(C.Root);
      break;
    }
    if (!Changed) {
      for (Value *Op : U->operands())
        Q.push_back(Op);
    }
  }
  // NOTE(review): when the worklist drains on exactly the Limit-th step,
  // Count equals Limit here and nullptr is returned even though the loop
  // was not cut short. Confirm whether that boundary case is intended.
  return Count < Limit ? C.Root : nullptr;
}

//===----------------------------------------------------------------------===//
//
//          Implementation of PolynomialMultiplyRecognize
//
//===----------------------------------------------------------------------===//

namespace {

  /// Recognizer for the polynomial-multiply loop patterns described in the
  /// comments of scanSelect(). It operates on a single loop and keeps
  /// references to the analyses handed to the constructor.
  class PolynomialMultiplyRecognize {
  public:
    explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl,
        const DominatorTree &dt, const TargetLibraryInfo &tli,
        ScalarEvolution &se)
      : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {}

    bool recognize();

  private:
    using ValueSeq = SetVector<Value *>;

    /// The integer type that values are promoted to: 32 bits wide.
    IntegerType *getPmpyType() const {
      LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext();
      return IntegerType::get(Ctx, 32);
    }

    bool isPromotableTo(Value *V, IntegerType *Ty);
    void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB);
    bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB);

    Value *getCountIV(BasicBlock *BB);
    bool findCycle(Value *Out, Value *In, ValueSeq &Cycle);
    void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early,
          ValueSeq &Late);
    bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late);
bool commutesWithShift(Instruction *I);
    bool highBitsAreZero(Value *V, unsigned IterCount);
    bool keepsHighBitsZero(Value *V, unsigned IterCount);
    bool isOperandShifted(Instruction *I, Value *Op);
    bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB,
          unsigned IterCount);
    void cleanupLoopBody(BasicBlock *LoopB);

    /// Values extracted from a matched select. X, Q, R and Left are filled
    /// in by matchLeftShift()/matchRightShift(); the remaining fields are
    /// set by code that uses those matchers.
    struct ParsedValues {
      ParsedValues() = default;

      Value *M = nullptr;
      Value *P = nullptr;
      Value *Q = nullptr;
      Value *R = nullptr;
      Value *X = nullptr;
      Instruction *Res = nullptr;
      unsigned IterCount = 0;
      bool Left = false;   // True when the left-shifting form was matched.
      bool Inv = false;
    };

    bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV);
    bool matchRightShift(SelectInst *SelI, ParsedValues &PV);
    bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB,
          Value *CIV, ParsedValues &PV, bool PreScan);
    unsigned getInverseMxN(unsigned QP);
    Value *generate(BasicBlock::iterator At, ParsedValues &PV);

    void setupPreSimplifier(Simplifier &S);
    void setupPostSimplifier(Simplifier &S);

    Loop *CurLoop;
    const DataLayout &DL;
    const DominatorTree &DT;
    const TargetLibraryInfo &TLI;
    ScalarEvolution &SE;
  };

} // end anonymous namespace

/// Looks in BB for a PHI node that starts at constant zero on entry and is
/// incremented by constant one along the edge from BB to itself. Returns
/// that PHI, or nullptr when BB does not have exactly two predecessors or no
/// such PHI exists.
Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) {
  pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
  if (std::distance(PI, PE) != 2)
    return nullptr;
  // PB is the predecessor that is not BB itself.
  BasicBlock *PB = (*PI == BB) ?
*std::next(PI) : *PI; 630 631 for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) { 632 auto *PN = cast<PHINode>(I); 633 Value *InitV = PN->getIncomingValueForBlock(PB); 634 if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero()) 635 continue; 636 Value *IterV = PN->getIncomingValueForBlock(BB); 637 auto *BO = dyn_cast<BinaryOperator>(IterV); 638 if (!BO) 639 continue; 640 if (BO->getOpcode() != Instruction::Add) 641 continue; 642 Value *IncV = nullptr; 643 if (BO->getOperand(0) == PN) 644 IncV = BO->getOperand(1); 645 else if (BO->getOperand(1) == PN) 646 IncV = BO->getOperand(0); 647 if (IncV == nullptr) 648 continue; 649 650 if (auto *T = dyn_cast<ConstantInt>(IncV)) 651 if (T->getZExtValue() == 1) 652 return PN; 653 } 654 return nullptr; 655 } 656 657 static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) { 658 for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) { 659 Use &TheUse = UI.getUse(); 660 ++UI; 661 if (auto *II = dyn_cast<Instruction>(TheUse.getUser())) 662 if (BB == II->getParent()) 663 II->replaceUsesOfWith(I, J); 664 } 665 } 666 667 bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI, 668 Value *CIV, ParsedValues &PV) { 669 // Match the following: 670 // select (X & (1 << i)) != 0 ? R ^ (Q << i) : R 671 // select (X & (1 << i)) == 0 ? R : R ^ (Q << i) 672 // The condition may also check for equality with the masked value, i.e 673 // select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R 674 // select (X & (1 << i)) != (1 << i) ? 
R : R ^ (Q << i); 675 676 Value *CondV = SelI->getCondition(); 677 Value *TrueV = SelI->getTrueValue(); 678 Value *FalseV = SelI->getFalseValue(); 679 680 using namespace PatternMatch; 681 682 CmpInst::Predicate P; 683 Value *A = nullptr, *B = nullptr, *C = nullptr; 684 685 if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) && 686 !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B))))) 687 return false; 688 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 689 return false; 690 // Matched: select (A & B) == C ? ... : ... 691 // select (A & B) != C ? ... : ... 692 693 Value *X = nullptr, *Sh1 = nullptr; 694 // Check (A & B) for (X & (1 << i)): 695 if (match(A, m_Shl(m_One(), m_Specific(CIV)))) { 696 Sh1 = A; 697 X = B; 698 } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) { 699 Sh1 = B; 700 X = A; 701 } else { 702 // TODO: Could also check for an induction variable containing single 703 // bit shifted left by 1 in each iteration. 704 return false; 705 } 706 707 bool TrueIfZero; 708 709 // Check C against the possible values for comparison: 0 and (1 << i): 710 if (match(C, m_Zero())) 711 TrueIfZero = (P == CmpInst::ICMP_EQ); 712 else if (C == Sh1) 713 TrueIfZero = (P == CmpInst::ICMP_NE); 714 else 715 return false; 716 717 // So far, matched: 718 // select (X & (1 << i)) ? ... : ... 719 // including variations of the check against zero/non-zero value. 720 721 Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr; 722 if (TrueIfZero) { 723 ShouldSameV = TrueV; 724 ShouldXoredV = FalseV; 725 } else { 726 ShouldSameV = FalseV; 727 ShouldXoredV = TrueV; 728 } 729 730 Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr; 731 Value *T = nullptr; 732 if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) { 733 // Matched: select +++ ? ... : Y ^ Z 734 // select +++ ? Y ^ Z : ... 735 // where +++ denotes previously checked matches. 
736 if (ShouldSameV == Y) 737 T = Z; 738 else if (ShouldSameV == Z) 739 T = Y; 740 else 741 return false; 742 R = ShouldSameV; 743 // Matched: select +++ ? R : R ^ T 744 // select +++ ? R ^ T : R 745 // depending on TrueIfZero. 746 747 } else if (match(ShouldSameV, m_Zero())) { 748 // Matched: select +++ ? 0 : ... 749 // select +++ ? ... : 0 750 if (!SelI->hasOneUse()) 751 return false; 752 T = ShouldXoredV; 753 // Matched: select +++ ? 0 : T 754 // select +++ ? T : 0 755 756 Value *U = *SelI->user_begin(); 757 if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) && 758 !match(U, m_Xor(m_Value(R), m_Specific(SelI)))) 759 return false; 760 // Matched: xor (select +++ ? 0 : T), R 761 // xor (select +++ ? T : 0), R 762 } else 763 return false; 764 765 // The xor input value T is isolated into its own match so that it could 766 // be checked against an induction variable containing a shifted bit 767 // (todo). 768 // For now, check against (Q << i). 769 if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) && 770 !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV))))) 771 return false; 772 // Matched: select +++ ? R : R ^ (Q << i) 773 // select +++ ? R ^ (Q << i) : R 774 775 PV.X = X; 776 PV.Q = Q; 777 PV.R = R; 778 PV.Left = true; 779 return true; 780 } 781 782 bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI, 783 ParsedValues &PV) { 784 // Match the following: 785 // select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1) 786 // select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q 787 // The condition may also check for equality with the masked value, i.e 788 // select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1) 789 // select (X & 1) != 1 ? 
(R >> 1) : (R >> 1) ^ Q 790 791 Value *CondV = SelI->getCondition(); 792 Value *TrueV = SelI->getTrueValue(); 793 Value *FalseV = SelI->getFalseValue(); 794 795 using namespace PatternMatch; 796 797 Value *C = nullptr; 798 CmpInst::Predicate P; 799 bool TrueIfZero; 800 801 if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) || 802 match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))) { 803 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 804 return false; 805 // Matched: select C == 0 ? ... : ... 806 // select C != 0 ? ... : ... 807 TrueIfZero = (P == CmpInst::ICMP_EQ); 808 } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) || 809 match(CondV, m_ICmp(P, m_One(), m_Value(C)))) { 810 if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE) 811 return false; 812 // Matched: select C == 1 ? ... : ... 813 // select C != 1 ? ... : ... 814 TrueIfZero = (P == CmpInst::ICMP_NE); 815 } else 816 return false; 817 818 Value *X = nullptr; 819 if (!match(C, m_And(m_Value(X), m_One())) && 820 !match(C, m_And(m_One(), m_Value(X)))) 821 return false; 822 // Matched: select (X & 1) == +++ ? ... : ... 823 // select (X & 1) != +++ ? ... : ... 824 825 Value *R = nullptr, *Q = nullptr; 826 if (TrueIfZero) { 827 // The select's condition is true if the tested bit is 0. 828 // TrueV must be the shift, FalseV must be the xor. 829 if (!match(TrueV, m_LShr(m_Value(R), m_One()))) 830 return false; 831 // Matched: select +++ ? (R >> 1) : ... 832 if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) && 833 !match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV)))) 834 return false; 835 // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q 836 // with commuting ^. 837 } else { 838 // The select's condition is true if the tested bit is 1. 839 // TrueV must be the xor, FalseV must be the shift. 840 if (!match(FalseV, m_LShr(m_Value(R), m_One()))) 841 return false; 842 // Matched: select +++ ? ... 
: (R >> 1) 843 if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) && 844 !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV)))) 845 return false; 846 // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1) 847 // with commuting ^. 848 } 849 850 PV.X = X; 851 PV.Q = Q; 852 PV.R = R; 853 PV.Left = false; 854 return true; 855 } 856 857 bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI, 858 BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV, 859 bool PreScan) { 860 using namespace PatternMatch; 861 862 // The basic pattern for R = P.Q is: 863 // for i = 0..31 864 // R = phi (0, R') 865 // if (P & (1 << i)) ; test-bit(P, i) 866 // R' = R ^ (Q << i) 867 // 868 // Similarly, the basic pattern for R = (P/Q).Q - P 869 // for i = 0..31 870 // R = phi(P, R') 871 // if (R & (1 << i)) 872 // R' = R ^ (Q << i) 873 874 // There exist idioms, where instead of Q being shifted left, P is shifted 875 // right. This produces a result that is shifted right by 32 bits (the 876 // non-shifted result is 64-bit). 877 // 878 // For R = P.Q, this would be: 879 // for i = 0..31 880 // R = phi (0, R') 881 // if ((P >> i) & 1) 882 // R' = (R >> 1) ^ Q ; R is cycled through the loop, so it must 883 // else ; be shifted by 1, not i. 884 // R' = R >> 1 885 // 886 // And for the inverse: 887 // for i = 0..31 888 // R = phi (P, R') 889 // if (R & 1) 890 // R' = (R >> 1) ^ Q 891 // else 892 // R' = R >> 1 893 894 // The left-shifting idioms share the same pattern: 895 // select (X & (1 << i)) ? R ^ (Q << i) : R 896 // Similarly for right-shifting idioms: 897 // select (X & 1) ? (R >> 1) ^ Q 898 899 if (matchLeftShift(SelI, CIV, PV)) { 900 // If this is a pre-scan, getting this far is sufficient. 901 if (PreScan) 902 return true; 903 904 // Need to make sure that the SelI goes back into R. 
auto *RPhi = dyn_cast<PHINode>(PV.R);
    if (!RPhi)
      return false;
    // R's value coming around the loop edge must be this very select.
    if (SelI != RPhi->getIncomingValueForBlock(LoopB))
      return false;
    PV.Res = SelI;

    // If X is loop invariant, it must be the input polynomial, and the
    // idiom is the basic polynomial multiply.
    if (CurLoop->isLoopInvariant(PV.X)) {
      PV.P = PV.X;
      PV.Inv = false;
    } else {
      // X is not loop invariant. If X == R, this is the inverse pmpy.
      // Otherwise, check for an xor with an invariant value. If the
      // variable argument to the xor is R, then this is still a valid
      // inverse pmpy.
      PV.Inv = true;
      if (PV.X != PV.R) {
        Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr;
        if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2))))
          return false;
        // An xor operand counts as the invariant one when it is not an
        // instruction located in LoopB.
        auto *I1 = dyn_cast<Instruction>(X1);
        auto *I2 = dyn_cast<Instruction>(X2);
        if (!I1 || I1->getParent() != LoopB) {
          Var = X2;
          Inv = X1;
        } else if (!I2 || I2->getParent() != LoopB) {
          Var = X1;
          Inv = X2;
        } else
          return false;
        if (Var != PV.R)
          return false;
        PV.M = Inv;
      }
      // The input polynomial P still needs to be determined. It will be
      // the entry value of R.
      Value *EntryP = RPhi->getIncomingValueForBlock(PrehB);
      PV.P = EntryP;
    }

    return true;
  }

  if (matchRightShift(SelI, PV)) {
    // If this is an inverse pattern, the Q polynomial must be known at
    // compile time.
    if (PV.Inv && !isa<ConstantInt>(PV.Q))
      return false;
    if (PreScan)
      return true;
    // There is no exact matching of right-shift pmpy.
return false;
  }

  return false;
}

/// Returns true when Val can be represented in DestTy without changing the
/// values computed: its type must be an integer type no wider than DestTy
/// and, when it is a narrower instruction, its opcode must be one of those
/// accepted by the switch below.
bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val,
      IntegerType *DestTy) {
  IntegerType *T = dyn_cast<IntegerType>(Val->getType());
  if (!T || T->getBitWidth() > DestTy->getBitWidth())
    return false;
  if (T->getBitWidth() == DestTy->getBitWidth())
    return true;
  // Non-instructions are promotable. The reason why an instruction may not
  // be promotable is that it may produce a different result if its operands
  // and the result are promoted, for example, it may produce more non-zero
  // bits. While it would still be possible to represent the proper result
  // in a wider type, it may require adding additional instructions (which
  // we don't want to do).
  Instruction *In = dyn_cast<Instruction>(Val);
  if (!In)
    return true;
  // The bitwidth of the source type is smaller than the destination.
  // Check if the individual operation can be promoted.
  switch (In->getOpcode()) {
    case Instruction::PHI:
    case Instruction::ZExt:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::LShr: // Shift right is ok.
    case Instruction::Select:
    case Instruction::Trunc:
      return true;
    case Instruction::ICmp:
      // Only equality and unsigned comparisons are accepted.
      if (CmpInst *CI = cast<CmpInst>(In))
        return CI->isEquality() || CI->isUnsigned();
      llvm_unreachable("Cast failed unexpectedly");
    case Instruction::Add:
      // An add is accepted only when it carries both no-wrap flags.
      return In->hasNoSignedWrap() && In->hasNoUnsignedWrap();
  }
  return false;
}

/// Rewrites In in place so that it operates on DestTy. LoopB is the loop
/// block; PHI inputs arriving from it are left for their own promotion.
void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
      IntegerType *DestTy, BasicBlock *LoopB) {
  Type *OrigTy = In->getType();
  assert(!OrigTy->isVoidTy() && "Invalid instruction to promote");

  // Leave boolean values alone.
  if (!In->getType()->isIntegerTy(1))
    In->mutateType(DestTy);
  unsigned DestBW = DestTy->getBitWidth();

  // Handle PHIs.
1013 if (PHINode *P = dyn_cast<PHINode>(In)) { 1014 unsigned N = P->getNumIncomingValues(); 1015 for (unsigned i = 0; i != N; ++i) { 1016 BasicBlock *InB = P->getIncomingBlock(i); 1017 if (InB == LoopB) 1018 continue; 1019 Value *InV = P->getIncomingValue(i); 1020 IntegerType *Ty = cast<IntegerType>(InV->getType()); 1021 // Do not promote values in PHI nodes of type i1. 1022 if (Ty != P->getType()) { 1023 // If the value type does not match the PHI type, the PHI type 1024 // must have been promoted. 1025 assert(Ty->getBitWidth() < DestBW); 1026 InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy); 1027 P->setIncomingValue(i, InV); 1028 } 1029 } 1030 } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) { 1031 Value *Op = Z->getOperand(0); 1032 if (Op->getType() == Z->getType()) 1033 Z->replaceAllUsesWith(Op); 1034 Z->eraseFromParent(); 1035 return; 1036 } 1037 if (TruncInst *T = dyn_cast<TruncInst>(In)) { 1038 IntegerType *TruncTy = cast<IntegerType>(OrigTy); 1039 Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1); 1040 Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask); 1041 T->replaceAllUsesWith(And); 1042 T->eraseFromParent(); 1043 return; 1044 } 1045 1046 // Promote immediates. 1047 for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) { 1048 if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i))) 1049 if (CI->getType()->getBitWidth() < DestBW) 1050 In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue())); 1051 } 1052 } 1053 1054 bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB, 1055 BasicBlock *ExitB) { 1056 assert(LoopB); 1057 // Skip loops where the exit block has more than one predecessor. The values 1058 // coming from the loop block will be promoted to another type, and so the 1059 // values coming into the exit block from other predecessors would also have 1060 // to be promoted. 
1061 if (!ExitB || (ExitB->getSinglePredecessor() != LoopB)) 1062 return false; 1063 IntegerType *DestTy = getPmpyType(); 1064 // Check if the exit values have types that are no wider than the type 1065 // that we want to promote to. 1066 unsigned DestBW = DestTy->getBitWidth(); 1067 for (PHINode &P : ExitB->phis()) { 1068 if (P.getNumIncomingValues() != 1) 1069 return false; 1070 assert(P.getIncomingBlock(0) == LoopB); 1071 IntegerType *T = dyn_cast<IntegerType>(P.getType()); 1072 if (!T || T->getBitWidth() > DestBW) 1073 return false; 1074 } 1075 1076 // Check all instructions in the loop. 1077 for (Instruction &In : *LoopB) 1078 if (!In.isTerminator() && !isPromotableTo(&In, DestTy)) 1079 return false; 1080 1081 // Perform the promotion. 1082 std::vector<Instruction*> LoopIns; 1083 std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns), 1084 [](Instruction &In) { return &In; }); 1085 for (Instruction *In : LoopIns) 1086 if (!In->isTerminator()) 1087 promoteTo(In, DestTy, LoopB); 1088 1089 // Fix up the PHI nodes in the exit block. 1090 Instruction *EndI = ExitB->getFirstNonPHI(); 1091 BasicBlock::iterator End = EndI ? EndI->getIterator() : ExitB->end(); 1092 for (auto I = ExitB->begin(); I != End; ++I) { 1093 PHINode *P = dyn_cast<PHINode>(I); 1094 if (!P) 1095 break; 1096 Type *Ty0 = P->getIncomingValue(0)->getType(); 1097 Type *PTy = P->getType(); 1098 if (PTy != Ty0) { 1099 assert(Ty0 == DestTy); 1100 // In order to create the trunc, P must have the promoted type. 1101 P->mutateType(Ty0); 1102 Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy); 1103 // In order for the RAUW to work, the types of P and T must match. 1104 P->mutateType(PTy); 1105 P->replaceAllUsesWith(T); 1106 // Final update of the P's type. 
1107 P->mutateType(Ty0); 1108 cast<Instruction>(T)->setOperand(0, P); 1109 } 1110 } 1111 1112 return true; 1113 } 1114 1115 bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In, 1116 ValueSeq &Cycle) { 1117 // Out = ..., In, ... 1118 if (Out == In) 1119 return true; 1120 1121 auto *BB = cast<Instruction>(Out)->getParent(); 1122 bool HadPhi = false; 1123 1124 for (auto U : Out->users()) { 1125 auto *I = dyn_cast<Instruction>(&*U); 1126 if (I == nullptr || I->getParent() != BB) 1127 continue; 1128 // Make sure that there are no multi-iteration cycles, e.g. 1129 // p1 = phi(p2) 1130 // p2 = phi(p1) 1131 // The cycle p1->p2->p1 would span two loop iterations. 1132 // Check that there is only one phi in the cycle. 1133 bool IsPhi = isa<PHINode>(I); 1134 if (IsPhi && HadPhi) 1135 return false; 1136 HadPhi |= IsPhi; 1137 if (Cycle.count(I)) 1138 return false; 1139 Cycle.insert(I); 1140 if (findCycle(I, In, Cycle)) 1141 break; 1142 Cycle.remove(I); 1143 } 1144 return !Cycle.empty(); 1145 } 1146 1147 void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI, 1148 ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) { 1149 // All the values in the cycle that are between the phi node and the 1150 // divider instruction will be classified as "early", all other values 1151 // will be "late". 1152 1153 bool IsE = true; 1154 unsigned I, N = Cycle.size(); 1155 for (I = 0; I < N; ++I) { 1156 Value *V = Cycle[I]; 1157 if (DivI == V) 1158 IsE = false; 1159 else if (!isa<PHINode>(V)) 1160 continue; 1161 // Stop if found either. 1162 break; 1163 } 1164 // "I" is the index of either DivI or the phi node, whichever was first. 1165 // "E" is "false" or "true" respectively. 1166 ValueSeq &First = !IsE ? Early : Late; 1167 for (unsigned J = 0; J < I; ++J) 1168 First.insert(Cycle[J]); 1169 1170 ValueSeq &Second = IsE ? 
Early : Late; 1171 Second.insert(Cycle[I]); 1172 for (++I; I < N; ++I) { 1173 Value *V = Cycle[I]; 1174 if (DivI == V || isa<PHINode>(V)) 1175 break; 1176 Second.insert(V); 1177 } 1178 1179 for (; I < N; ++I) 1180 First.insert(Cycle[I]); 1181 } 1182 1183 bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI, 1184 ValueSeq &Early, ValueSeq &Late) { 1185 // Select is an exception, since the condition value does not have to be 1186 // classified in the same way as the true/false values. The true/false 1187 // values do have to be both early or both late. 1188 if (UseI->getOpcode() == Instruction::Select) { 1189 Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2); 1190 if (Early.count(TV) || Early.count(FV)) { 1191 if (Late.count(TV) || Late.count(FV)) 1192 return false; 1193 Early.insert(UseI); 1194 } else if (Late.count(TV) || Late.count(FV)) { 1195 if (Early.count(TV) || Early.count(FV)) 1196 return false; 1197 Late.insert(UseI); 1198 } 1199 return true; 1200 } 1201 1202 // Not sure what would be the example of this, but the code below relies 1203 // on having at least one operand. 1204 if (UseI->getNumOperands() == 0) 1205 return true; 1206 1207 bool AE = true, AL = true; 1208 for (auto &I : UseI->operands()) { 1209 if (Early.count(&*I)) 1210 AL = false; 1211 else if (Late.count(&*I)) 1212 AE = false; 1213 } 1214 // If the operands appear "all early" and "all late" at the same time, 1215 // then it means that none of them are actually classified as either. 1216 // This is harmless. 1217 if (AE && AL) 1218 return true; 1219 // Conversely, if they are neither "all early" nor "all late", then 1220 // we have a mixture of early and late operands that is not a known 1221 // exception. 1222 if (!AE && !AL) 1223 return false; 1224 1225 // Check that we have covered the two special cases. 
1226 assert(AE != AL); 1227 1228 if (AE) 1229 Early.insert(UseI); 1230 else 1231 Late.insert(UseI); 1232 return true; 1233 } 1234 1235 bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) { 1236 switch (I->getOpcode()) { 1237 case Instruction::And: 1238 case Instruction::Or: 1239 case Instruction::Xor: 1240 case Instruction::LShr: 1241 case Instruction::Shl: 1242 case Instruction::Select: 1243 case Instruction::ICmp: 1244 case Instruction::PHI: 1245 break; 1246 default: 1247 return false; 1248 } 1249 return true; 1250 } 1251 1252 bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V, 1253 unsigned IterCount) { 1254 auto *T = dyn_cast<IntegerType>(V->getType()); 1255 if (!T) 1256 return false; 1257 1258 KnownBits Known(T->getBitWidth()); 1259 computeKnownBits(V, Known, DL); 1260 return Known.countMinLeadingZeros() >= IterCount; 1261 } 1262 1263 bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V, 1264 unsigned IterCount) { 1265 // Assume that all inputs to the value have the high bits zero. 1266 // Check if the value itself preserves the zeros in the high bits. 
1267 if (auto *C = dyn_cast<ConstantInt>(V)) 1268 return C->getValue().countLeadingZeros() >= IterCount; 1269 1270 if (auto *I = dyn_cast<Instruction>(V)) { 1271 switch (I->getOpcode()) { 1272 case Instruction::And: 1273 case Instruction::Or: 1274 case Instruction::Xor: 1275 case Instruction::LShr: 1276 case Instruction::Select: 1277 case Instruction::ICmp: 1278 case Instruction::PHI: 1279 case Instruction::ZExt: 1280 return true; 1281 } 1282 } 1283 1284 return false; 1285 } 1286 1287 bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) { 1288 unsigned Opc = I->getOpcode(); 1289 if (Opc == Instruction::Shl || Opc == Instruction::LShr) 1290 return Op != I->getOperand(1); 1291 return true; 1292 } 1293 1294 bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB, 1295 BasicBlock *ExitB, unsigned IterCount) { 1296 Value *CIV = getCountIV(LoopB); 1297 if (CIV == nullptr) 1298 return false; 1299 auto *CIVTy = dyn_cast<IntegerType>(CIV->getType()); 1300 if (CIVTy == nullptr) 1301 return false; 1302 1303 ValueSeq RShifts; 1304 ValueSeq Early, Late, Cycled; 1305 1306 // Find all value cycles that contain logical right shifts by 1. 1307 for (Instruction &I : *LoopB) { 1308 using namespace PatternMatch; 1309 1310 Value *V = nullptr; 1311 if (!match(&I, m_LShr(m_Value(V), m_One()))) 1312 continue; 1313 ValueSeq C; 1314 if (!findCycle(&I, V, C)) 1315 continue; 1316 1317 // Found a cycle. 1318 C.insert(&I); 1319 classifyCycle(&I, C, Early, Late); 1320 Cycled.insert(C.begin(), C.end()); 1321 RShifts.insert(&I); 1322 } 1323 1324 // Find the set of all values affected by the shift cycles, i.e. all 1325 // cycled values, and (recursively) all their users. 
1326 ValueSeq Users(Cycled.begin(), Cycled.end()); 1327 for (unsigned i = 0; i < Users.size(); ++i) { 1328 Value *V = Users[i]; 1329 if (!isa<IntegerType>(V->getType())) 1330 return false; 1331 auto *R = cast<Instruction>(V); 1332 // If the instruction does not commute with shifts, the loop cannot 1333 // be unshifted. 1334 if (!commutesWithShift(R)) 1335 return false; 1336 for (auto I = R->user_begin(), E = R->user_end(); I != E; ++I) { 1337 auto *T = cast<Instruction>(*I); 1338 // Skip users from outside of the loop. They will be handled later. 1339 // Also, skip the right-shifts and phi nodes, since they mix early 1340 // and late values. 1341 if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T)) 1342 continue; 1343 1344 Users.insert(T); 1345 if (!classifyInst(T, Early, Late)) 1346 return false; 1347 } 1348 } 1349 1350 if (Users.empty()) 1351 return false; 1352 1353 // Verify that high bits remain zero. 1354 ValueSeq Internal(Users.begin(), Users.end()); 1355 ValueSeq Inputs; 1356 for (unsigned i = 0; i < Internal.size(); ++i) { 1357 auto *R = dyn_cast<Instruction>(Internal[i]); 1358 if (!R) 1359 continue; 1360 for (Value *Op : R->operands()) { 1361 auto *T = dyn_cast<Instruction>(Op); 1362 if (T && T->getParent() != LoopB) 1363 Inputs.insert(Op); 1364 else 1365 Internal.insert(Op); 1366 } 1367 } 1368 for (Value *V : Inputs) 1369 if (!highBitsAreZero(V, IterCount)) 1370 return false; 1371 for (Value *V : Internal) 1372 if (!keepsHighBitsZero(V, IterCount)) 1373 return false; 1374 1375 // Finally, the work can be done. Unshift each user. 
1376 IRBuilder<> IRB(LoopB); 1377 std::map<Value*,Value*> ShiftMap; 1378 1379 using CastMapType = std::map<std::pair<Value *, Type *>, Value *>; 1380 1381 CastMapType CastMap; 1382 1383 auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V, 1384 IntegerType *Ty) -> Value* { 1385 auto H = CM.find(std::make_pair(V, Ty)); 1386 if (H != CM.end()) 1387 return H->second; 1388 Value *CV = IRB.CreateIntCast(V, Ty, false); 1389 CM.insert(std::make_pair(std::make_pair(V, Ty), CV)); 1390 return CV; 1391 }; 1392 1393 for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) { 1394 using namespace PatternMatch; 1395 1396 if (isa<PHINode>(I) || !Users.count(&*I)) 1397 continue; 1398 1399 // Match lshr x, 1. 1400 Value *V = nullptr; 1401 if (match(&*I, m_LShr(m_Value(V), m_One()))) { 1402 replaceAllUsesOfWithIn(&*I, V, LoopB); 1403 continue; 1404 } 1405 // For each non-cycled operand, replace it with the corresponding 1406 // value shifted left. 1407 for (auto &J : I->operands()) { 1408 Value *Op = J.get(); 1409 if (!isOperandShifted(&*I, Op)) 1410 continue; 1411 if (Users.count(Op)) 1412 continue; 1413 // Skip shifting zeros. 1414 if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) 1415 continue; 1416 // Check if we have already generated a shift for this value. 1417 auto F = ShiftMap.find(Op); 1418 Value *W = (F != ShiftMap.end()) ? F->second : nullptr; 1419 if (W == nullptr) { 1420 IRB.SetInsertPoint(&*I); 1421 // First, the shift amount will be CIV or CIV+1, depending on 1422 // whether the value is early or late. Instead of creating CIV+1, 1423 // do a single shift of the value. 1424 Value *ShAmt = CIV, *ShVal = Op; 1425 auto *VTy = cast<IntegerType>(ShVal->getType()); 1426 auto *ATy = cast<IntegerType>(ShAmt->getType()); 1427 if (Late.count(&*I)) 1428 ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1)); 1429 // Second, the types of the shifted value and the shift amount 1430 // must match. 
1431 if (VTy != ATy) { 1432 if (VTy->getBitWidth() < ATy->getBitWidth()) 1433 ShVal = upcast(CastMap, IRB, ShVal, ATy); 1434 else 1435 ShAmt = upcast(CastMap, IRB, ShAmt, VTy); 1436 } 1437 // Ready to generate the shift and memoize it. 1438 W = IRB.CreateShl(ShVal, ShAmt); 1439 ShiftMap.insert(std::make_pair(Op, W)); 1440 } 1441 I->replaceUsesOfWith(Op, W); 1442 } 1443 } 1444 1445 // Update the users outside of the loop to account for having left 1446 // shifts. They would normally be shifted right in the loop, so shift 1447 // them right after the loop exit. 1448 // Take advantage of the loop-closed SSA form, which has all the post- 1449 // loop values in phi nodes. 1450 IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt()); 1451 for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) { 1452 if (!isa<PHINode>(P)) 1453 break; 1454 auto *PN = cast<PHINode>(P); 1455 Value *U = PN->getIncomingValueForBlock(LoopB); 1456 if (!Users.count(U)) 1457 continue; 1458 Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount)); 1459 PN->replaceAllUsesWith(S); 1460 // The above RAUW will create 1461 // S = lshr S, IterCount 1462 // so we need to fix it back into 1463 // S = lshr PN, IterCount 1464 cast<User>(S)->replaceUsesOfWith(S, PN); 1465 } 1466 1467 return true; 1468 } 1469 1470 void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) { 1471 for (auto &I : *LoopB) 1472 if (Value *SV = SimplifyInstruction(&I, {DL, &TLI, &DT})) 1473 I.replaceAllUsesWith(SV); 1474 1475 for (auto I = LoopB->begin(), N = I; I != LoopB->end(); I = N) { 1476 N = std::next(I); 1477 RecursivelyDeleteTriviallyDeadInstructions(&*I, &TLI); 1478 } 1479 } 1480 1481 unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) { 1482 // Arrays of coefficients of Q and the inverse, C. 1483 // Q[i] = coefficient at x^i. 
1484 std::array<char,32> Q, C; 1485 1486 for (unsigned i = 0; i < 32; ++i) { 1487 Q[i] = QP & 1; 1488 QP >>= 1; 1489 } 1490 assert(Q[0] == 1); 1491 1492 // Find C, such that 1493 // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1 1494 // 1495 // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the 1496 // operations * and + are & and ^ respectively. 1497 // 1498 // Find C[i] recursively, by comparing i-th coefficient in the product 1499 // with 0 (or 1 for i=0). 1500 // 1501 // C[0] = 1, since C[0] = Q[0], and Q[0] = 1. 1502 C[0] = 1; 1503 for (unsigned i = 1; i < 32; ++i) { 1504 // Solve for C[i] in: 1505 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0 1506 // This is equivalent to 1507 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0 1508 // which is 1509 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i] 1510 unsigned T = 0; 1511 for (unsigned j = 0; j < i; ++j) 1512 T = T ^ (C[j] & Q[i-j]); 1513 C[i] = T; 1514 } 1515 1516 unsigned QV = 0; 1517 for (unsigned i = 0; i < 32; ++i) 1518 if (C[i]) 1519 QV |= (1 << i); 1520 1521 return QV; 1522 } 1523 1524 Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At, 1525 ParsedValues &PV) { 1526 IRBuilder<> B(&*At); 1527 Module *M = At->getParent()->getParent()->getParent(); 1528 Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw); 1529 1530 Value *P = PV.P, *Q = PV.Q, *P0 = P; 1531 unsigned IC = PV.IterCount; 1532 1533 if (PV.M != nullptr) 1534 P0 = P = B.CreateXor(P, PV.M); 1535 1536 // Create a bit mask to clear the high bits beyond IterCount. 1537 auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC)); 1538 1539 if (PV.IterCount != 32) 1540 P = B.CreateAnd(P, BMI); 1541 1542 if (PV.Inv) { 1543 auto *QI = dyn_cast<ConstantInt>(PV.Q); 1544 assert(QI && QI->getBitWidth() <= 32); 1545 1546 // Again, clearing bits beyond IterCount. 
1547 unsigned M = (1 << PV.IterCount) - 1; 1548 unsigned Tmp = (QI->getZExtValue() | 1) & M; 1549 unsigned QV = getInverseMxN(Tmp) & M; 1550 auto *QVI = ConstantInt::get(QI->getType(), QV); 1551 P = B.CreateCall(PMF, {P, QVI}); 1552 P = B.CreateTrunc(P, QI->getType()); 1553 if (IC != 32) 1554 P = B.CreateAnd(P, BMI); 1555 } 1556 1557 Value *R = B.CreateCall(PMF, {P, Q}); 1558 1559 if (PV.M != nullptr) 1560 R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false)); 1561 1562 return R; 1563 } 1564 1565 static bool hasZeroSignBit(const Value *V) { 1566 if (const auto *CI = dyn_cast<const ConstantInt>(V)) 1567 return (CI->getType()->getSignBit() & CI->getSExtValue()) == 0; 1568 const Instruction *I = dyn_cast<const Instruction>(V); 1569 if (!I) 1570 return false; 1571 switch (I->getOpcode()) { 1572 case Instruction::LShr: 1573 if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1))) 1574 return SI->getZExtValue() > 0; 1575 return false; 1576 case Instruction::Or: 1577 case Instruction::Xor: 1578 return hasZeroSignBit(I->getOperand(0)) && 1579 hasZeroSignBit(I->getOperand(1)); 1580 case Instruction::And: 1581 return hasZeroSignBit(I->getOperand(0)) || 1582 hasZeroSignBit(I->getOperand(1)); 1583 } 1584 return false; 1585 } 1586 1587 void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) { 1588 S.addRule("sink-zext", 1589 // Sink zext past bitwise operations. 
1590 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1591 if (I->getOpcode() != Instruction::ZExt) 1592 return nullptr; 1593 Instruction *T = dyn_cast<Instruction>(I->getOperand(0)); 1594 if (!T) 1595 return nullptr; 1596 switch (T->getOpcode()) { 1597 case Instruction::And: 1598 case Instruction::Or: 1599 case Instruction::Xor: 1600 break; 1601 default: 1602 return nullptr; 1603 } 1604 IRBuilder<> B(Ctx); 1605 return B.CreateBinOp(cast<BinaryOperator>(T)->getOpcode(), 1606 B.CreateZExt(T->getOperand(0), I->getType()), 1607 B.CreateZExt(T->getOperand(1), I->getType())); 1608 }); 1609 S.addRule("xor/and -> and/xor", 1610 // (xor (and x a) (and y a)) -> (and (xor x y) a) 1611 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1612 if (I->getOpcode() != Instruction::Xor) 1613 return nullptr; 1614 Instruction *And0 = dyn_cast<Instruction>(I->getOperand(0)); 1615 Instruction *And1 = dyn_cast<Instruction>(I->getOperand(1)); 1616 if (!And0 || !And1) 1617 return nullptr; 1618 if (And0->getOpcode() != Instruction::And || 1619 And1->getOpcode() != Instruction::And) 1620 return nullptr; 1621 if (And0->getOperand(1) != And1->getOperand(1)) 1622 return nullptr; 1623 IRBuilder<> B(Ctx); 1624 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1->getOperand(0)), 1625 And0->getOperand(1)); 1626 }); 1627 S.addRule("sink binop into select", 1628 // (Op (select c x y) z) -> (select c (Op x z) (Op y z)) 1629 // (Op x (select c y z)) -> (select c (Op x y) (Op x z)) 1630 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1631 BinaryOperator *BO = dyn_cast<BinaryOperator>(I); 1632 if (!BO) 1633 return nullptr; 1634 Instruction::BinaryOps Op = BO->getOpcode(); 1635 if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(0))) { 1636 IRBuilder<> B(Ctx); 1637 Value *X = Sel->getTrueValue(), *Y = Sel->getFalseValue(); 1638 Value *Z = BO->getOperand(1); 1639 return B.CreateSelect(Sel->getCondition(), 1640 B.CreateBinOp(Op, X, Z), 1641 B.CreateBinOp(Op, Y, Z)); 1642 } 1643 if (SelectInst *Sel 
= dyn_cast<SelectInst>(BO->getOperand(1))) { 1644 IRBuilder<> B(Ctx); 1645 Value *X = BO->getOperand(0); 1646 Value *Y = Sel->getTrueValue(), *Z = Sel->getFalseValue(); 1647 return B.CreateSelect(Sel->getCondition(), 1648 B.CreateBinOp(Op, X, Y), 1649 B.CreateBinOp(Op, X, Z)); 1650 } 1651 return nullptr; 1652 }); 1653 S.addRule("fold select-select", 1654 // (select c (select c x y) z) -> (select c x z) 1655 // (select c x (select c y z)) -> (select c x z) 1656 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1657 SelectInst *Sel = dyn_cast<SelectInst>(I); 1658 if (!Sel) 1659 return nullptr; 1660 IRBuilder<> B(Ctx); 1661 Value *C = Sel->getCondition(); 1662 if (SelectInst *Sel0 = dyn_cast<SelectInst>(Sel->getTrueValue())) { 1663 if (Sel0->getCondition() == C) 1664 return B.CreateSelect(C, Sel0->getTrueValue(), Sel->getFalseValue()); 1665 } 1666 if (SelectInst *Sel1 = dyn_cast<SelectInst>(Sel->getFalseValue())) { 1667 if (Sel1->getCondition() == C) 1668 return B.CreateSelect(C, Sel->getTrueValue(), Sel1->getFalseValue()); 1669 } 1670 return nullptr; 1671 }); 1672 S.addRule("or-signbit -> xor-signbit", 1673 // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0) 1674 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1675 if (I->getOpcode() != Instruction::Or) 1676 return nullptr; 1677 ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1)); 1678 if (!Msb || Msb->getZExtValue() != Msb->getType()->getSignBit()) 1679 return nullptr; 1680 if (!hasZeroSignBit(I->getOperand(0))) 1681 return nullptr; 1682 return IRBuilder<>(Ctx).CreateXor(I->getOperand(0), Msb); 1683 }); 1684 S.addRule("sink lshr into binop", 1685 // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c)) 1686 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1687 if (I->getOpcode() != Instruction::LShr) 1688 return nullptr; 1689 BinaryOperator *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0)); 1690 if (!BitOp) 1691 return nullptr; 1692 switch (BitOp->getOpcode()) { 1693 case Instruction::And: 1694 case 
Instruction::Or: 1695 case Instruction::Xor: 1696 break; 1697 default: 1698 return nullptr; 1699 } 1700 IRBuilder<> B(Ctx); 1701 Value *S = I->getOperand(1); 1702 return B.CreateBinOp(BitOp->getOpcode(), 1703 B.CreateLShr(BitOp->getOperand(0), S), 1704 B.CreateLShr(BitOp->getOperand(1), S)); 1705 }); 1706 S.addRule("expose bitop-const", 1707 // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b)) 1708 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1709 auto IsBitOp = [](unsigned Op) -> bool { 1710 switch (Op) { 1711 case Instruction::And: 1712 case Instruction::Or: 1713 case Instruction::Xor: 1714 return true; 1715 } 1716 return false; 1717 }; 1718 BinaryOperator *BitOp1 = dyn_cast<BinaryOperator>(I); 1719 if (!BitOp1 || !IsBitOp(BitOp1->getOpcode())) 1720 return nullptr; 1721 BinaryOperator *BitOp2 = dyn_cast<BinaryOperator>(BitOp1->getOperand(0)); 1722 if (!BitOp2 || !IsBitOp(BitOp2->getOpcode())) 1723 return nullptr; 1724 ConstantInt *CA = dyn_cast<ConstantInt>(BitOp2->getOperand(1)); 1725 ConstantInt *CB = dyn_cast<ConstantInt>(BitOp1->getOperand(1)); 1726 if (!CA || !CB) 1727 return nullptr; 1728 IRBuilder<> B(Ctx); 1729 Value *X = BitOp2->getOperand(0); 1730 return B.CreateBinOp(BitOp2->getOpcode(), X, 1731 B.CreateBinOp(BitOp1->getOpcode(), CA, CB)); 1732 }); 1733 } 1734 1735 void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) { 1736 S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a", 1737 [](Instruction *I, LLVMContext &Ctx) -> Value* { 1738 if (I->getOpcode() != Instruction::And) 1739 return nullptr; 1740 Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0)); 1741 ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1)); 1742 if (!Xor || !C0) 1743 return nullptr; 1744 if (Xor->getOpcode() != Instruction::Xor) 1745 return nullptr; 1746 Instruction *And0 = dyn_cast<Instruction>(Xor->getOperand(0)); 1747 Instruction *And1 = dyn_cast<Instruction>(Xor->getOperand(1)); 1748 // Pick the first non-null and. 
1749 if (!And0 || And0->getOpcode() != Instruction::And) 1750 std::swap(And0, And1); 1751 ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1)); 1752 if (!C1) 1753 return nullptr; 1754 uint32_t V0 = C0->getZExtValue(); 1755 uint32_t V1 = C1->getZExtValue(); 1756 if (V0 != (V0 & V1)) 1757 return nullptr; 1758 IRBuilder<> B(Ctx); 1759 return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1), C0); 1760 }); 1761 } 1762 1763 bool PolynomialMultiplyRecognize::recognize() { 1764 LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n" 1765 << *CurLoop << '\n'); 1766 // Restrictions: 1767 // - The loop must consist of a single block. 1768 // - The iteration count must be known at compile-time. 1769 // - The loop must have an induction variable starting from 0, and 1770 // incremented in each iteration of the loop. 1771 BasicBlock *LoopB = CurLoop->getHeader(); 1772 LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB); 1773 1774 if (LoopB != CurLoop->getLoopLatch()) 1775 return false; 1776 BasicBlock *ExitB = CurLoop->getExitBlock(); 1777 if (ExitB == nullptr) 1778 return false; 1779 BasicBlock *EntryB = CurLoop->getLoopPreheader(); 1780 if (EntryB == nullptr) 1781 return false; 1782 1783 unsigned IterCount = 0; 1784 const SCEV *CT = SE.getBackedgeTakenCount(CurLoop); 1785 if (isa<SCEVCouldNotCompute>(CT)) 1786 return false; 1787 if (auto *CV = dyn_cast<SCEVConstant>(CT)) 1788 IterCount = CV->getValue()->getZExtValue() + 1; 1789 1790 Value *CIV = getCountIV(LoopB); 1791 ParsedValues PV; 1792 Simplifier PreSimp; 1793 PV.IterCount = IterCount; 1794 LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount 1795 << '\n'); 1796 1797 setupPreSimplifier(PreSimp); 1798 1799 // Perform a preliminary scan of select instructions to see if any of them 1800 // looks like a generator of the polynomial multiply steps. 
Assume that a 1801 // loop can only contain a single transformable operation, so stop the 1802 // traversal after the first reasonable candidate was found. 1803 // XXX: Currently this approach can modify the loop before being 100% sure 1804 // that the transformation can be carried out. 1805 bool FoundPreScan = false; 1806 auto FeedsPHI = [LoopB](const Value *V) -> bool { 1807 for (const Value *U : V->users()) { 1808 if (const auto *P = dyn_cast<const PHINode>(U)) 1809 if (P->getParent() == LoopB) 1810 return true; 1811 } 1812 return false; 1813 }; 1814 for (Instruction &In : *LoopB) { 1815 SelectInst *SI = dyn_cast<SelectInst>(&In); 1816 if (!SI || !FeedsPHI(SI)) 1817 continue; 1818 1819 Simplifier::Context C(SI); 1820 Value *T = PreSimp.simplify(C); 1821 SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : SI; 1822 LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n'); 1823 if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) { 1824 FoundPreScan = true; 1825 if (SelI != SI) { 1826 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1827 SI->replaceAllUsesWith(NewSel); 1828 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1829 } 1830 break; 1831 } 1832 } 1833 1834 if (!FoundPreScan) { 1835 LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n"); 1836 return false; 1837 } 1838 1839 if (!PV.Left) { 1840 // The right shift version actually only returns the higher bits of 1841 // the result (each iteration discards the LSB). If we want to convert it 1842 // to a left-shifting loop, the working data type must be at least as 1843 // wide as the target's pmpy instruction. 1844 if (!promoteTypes(LoopB, ExitB)) 1845 return false; 1846 // Run post-promotion simplifications. 
1847 Simplifier PostSimp; 1848 setupPostSimplifier(PostSimp); 1849 for (Instruction &In : *LoopB) { 1850 SelectInst *SI = dyn_cast<SelectInst>(&In); 1851 if (!SI || !FeedsPHI(SI)) 1852 continue; 1853 Simplifier::Context C(SI); 1854 Value *T = PostSimp.simplify(C); 1855 SelectInst *SelI = dyn_cast_or_null<SelectInst>(T); 1856 if (SelI != SI) { 1857 Value *NewSel = C.materialize(LoopB, SI->getIterator()); 1858 SI->replaceAllUsesWith(NewSel); 1859 RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI); 1860 } 1861 break; 1862 } 1863 1864 if (!convertShiftsToLeft(LoopB, ExitB, IterCount)) 1865 return false; 1866 cleanupLoopBody(LoopB); 1867 } 1868 1869 // Scan the loop again, find the generating select instruction. 1870 bool FoundScan = false; 1871 for (Instruction &In : *LoopB) { 1872 SelectInst *SelI = dyn_cast<SelectInst>(&In); 1873 if (!SelI) 1874 continue; 1875 LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n'); 1876 FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false); 1877 if (FoundScan) 1878 break; 1879 } 1880 assert(FoundScan); 1881 1882 LLVM_DEBUG({ 1883 StringRef PP = (PV.M ? 
"(P+M)" : "P"); 1884 if (!PV.Inv) 1885 dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n"; 1886 else 1887 dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + " 1888 << PP << "\n"; 1889 dbgs() << " Res:" << *PV.Res << "\n P:" << *PV.P << "\n"; 1890 if (PV.M) 1891 dbgs() << " M:" << *PV.M << "\n"; 1892 dbgs() << " Q:" << *PV.Q << "\n"; 1893 dbgs() << " Iteration count:" << PV.IterCount << "\n"; 1894 }); 1895 1896 BasicBlock::iterator At(EntryB->getTerminator()); 1897 Value *PM = generate(At, PV); 1898 if (PM == nullptr) 1899 return false; 1900 1901 if (PM->getType() != PV.Res->getType()) 1902 PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false); 1903 1904 PV.Res->replaceAllUsesWith(PM); 1905 PV.Res->eraseFromParent(); 1906 return true; 1907 } 1908 1909 int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) { 1910 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1))) 1911 return SC->getAPInt().getSExtValue(); 1912 return 0; 1913 } 1914 1915 bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) { 1916 // Allow volatile stores if HexagonVolatileMemcpy is enabled. 1917 if (!(SI->isVolatile() && HexagonVolatileMemcpy) && !SI->isSimple()) 1918 return false; 1919 1920 Value *StoredVal = SI->getValueOperand(); 1921 Value *StorePtr = SI->getPointerOperand(); 1922 1923 // Reject stores that are so large that they overflow an unsigned. 1924 uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); 1925 if ((SizeInBits & 7) || (SizeInBits >> 32) != 0) 1926 return false; 1927 1928 // See if the pointer expression is an AddRec like {base,+,1} on the current 1929 // loop, which indicates a strided store. If we have something else, it's a 1930 // random store we can't handle. 
1931 auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 1932 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine()) 1933 return false; 1934 1935 // Check to see if the stride matches the size of the store. If so, then we 1936 // know that every byte is touched in the loop. 1937 int Stride = getSCEVStride(StoreEv); 1938 if (Stride == 0) 1939 return false; 1940 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 1941 if (StoreSize != unsigned(std::abs(Stride))) 1942 return false; 1943 1944 // The store must be feeding a non-volatile load. 1945 LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); 1946 if (!LI || !LI->isSimple()) 1947 return false; 1948 1949 // See if the pointer expression is an AddRec like {base,+,1} on the current 1950 // loop, which indicates a strided load. If we have something else, it's a 1951 // random load we can't handle. 1952 Value *LoadPtr = LI->getPointerOperand(); 1953 auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr)); 1954 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine()) 1955 return false; 1956 1957 // The store and load must share the same stride. 1958 if (StoreEv->getOperand(1) != LoadEv->getOperand(1)) 1959 return false; 1960 1961 // Success. This store can be converted into a memcpy. 1962 return true; 1963 } 1964 1965 /// mayLoopAccessLocation - Return true if the specified loop might access the 1966 /// specified pointer location, which is a loop-strided access. The 'Access' 1967 /// argument specifies what the verboten forms of access are (read or write). 1968 static bool 1969 mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, 1970 const SCEV *BECount, unsigned StoreSize, 1971 AliasAnalysis &AA, 1972 SmallPtrSetImpl<Instruction *> &Ignored) { 1973 // Get the location that may be stored across the loop. 
Since the access 1974 // is strided positively through memory, we say that the modified location 1975 // starts at the pointer and has infinite size. 1976 LocationSize AccessSize = LocationSize::unknown(); 1977 1978 // If the loop iterates a fixed number of times, we can refine the access 1979 // size to be exactly the size of the memset, which is (BECount+1)*StoreSize 1980 if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) 1981 AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) * 1982 StoreSize); 1983 1984 // TODO: For this to be really effective, we have to dive into the pointer 1985 // operand in the store. Store to &A[i] of 100 will always return may alias 1986 // with store of &A[100], we need to StoreLoc to be "A" with size of 100, 1987 // which will then no-alias a store to &A[100]. 1988 MemoryLocation StoreLoc(Ptr, AccessSize); 1989 1990 for (auto *B : L->blocks()) 1991 for (auto &I : *B) 1992 if (Ignored.count(&I) == 0 && 1993 isModOrRefSet( 1994 intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) 1995 return true; 1996 1997 return false; 1998 } 1999 2000 void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB, 2001 SmallVectorImpl<StoreInst*> &Stores) { 2002 Stores.clear(); 2003 for (Instruction &I : *BB) 2004 if (StoreInst *SI = dyn_cast<StoreInst>(&I)) 2005 if (isLegalStore(CurLoop, SI)) 2006 Stores.push_back(SI); 2007 } 2008 2009 bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop, 2010 StoreInst *SI, const SCEV *BECount) { 2011 assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) && 2012 "Expected only non-volatile stores, or Hexagon-specific memcpy" 2013 "to volatile destination."); 2014 2015 Value *StorePtr = SI->getPointerOperand(); 2016 auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); 2017 unsigned Stride = getSCEVStride(StoreEv); 2018 unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); 2019 if (Stride != StoreSize) 
2020 return false; 2021 2022 // See if the pointer expression is an AddRec like {base,+,1} on the current 2023 // loop, which indicates a strided load. If we have something else, it's a 2024 // random load we can't handle. 2025 auto *LI = cast<LoadInst>(SI->getValueOperand()); 2026 auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); 2027 2028 // The trip count of the loop and the base pointer of the addrec SCEV is 2029 // guaranteed to be loop invariant, which means that it should dominate the 2030 // header. This allows us to insert code for it in the preheader. 2031 BasicBlock *Preheader = CurLoop->getLoopPreheader(); 2032 Instruction *ExpPt = Preheader->getTerminator(); 2033 IRBuilder<> Builder(ExpPt); 2034 SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom"); 2035 2036 Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace()); 2037 2038 // Okay, we have a strided store "p[i]" of a loaded value. We can turn 2039 // this into a memcpy/memmove in the loop preheader now if we want. However, 2040 // this would be unsafe to do if there is anything else in the loop that may 2041 // read or write the memory region we're storing to. For memcpy, this 2042 // includes the load that feeds the stores. Check for an alias by generating 2043 // the base address and checking everything. 2044 Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(), 2045 Builder.getInt8PtrTy(SI->getPointerAddressSpace()), ExpPt); 2046 Value *LoadBasePtr = nullptr; 2047 2048 bool Overlap = false; 2049 bool DestVolatile = SI->isVolatile(); 2050 Type *BECountTy = BECount->getType(); 2051 2052 if (DestVolatile) { 2053 // The trip count must fit in i32, since it is the type of the "num_words" 2054 // argument to hexagon_memcpy_forward_vp4cp4n2. 2055 if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) { 2056 CleanupAndExit: 2057 // If we generated new code for the base pointer, clean up. 
2058 Expander.clear(); 2059 if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) { 2060 RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); 2061 StoreBasePtr = nullptr; 2062 } 2063 if (LoadBasePtr) { 2064 RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); 2065 LoadBasePtr = nullptr; 2066 } 2067 return false; 2068 } 2069 } 2070 2071 SmallPtrSet<Instruction*, 2> Ignore1; 2072 Ignore1.insert(SI); 2073 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, 2074 StoreSize, *AA, Ignore1)) { 2075 // Check if the load is the offending instruction. 2076 Ignore1.insert(LI); 2077 if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, 2078 BECount, StoreSize, *AA, Ignore1)) { 2079 // Still bad. Nothing we can do. 2080 goto CleanupAndExit; 2081 } 2082 // It worked with the load ignored. 2083 Overlap = true; 2084 } 2085 2086 if (!Overlap) { 2087 if (DisableMemcpyIdiom || !HasMemcpy) 2088 goto CleanupAndExit; 2089 } else { 2090 // Don't generate memmove if this function will be inlined. This is 2091 // because the caller will undergo this transformation after inlining. 2092 Function *Func = CurLoop->getHeader()->getParent(); 2093 if (Func->hasFnAttribute(Attribute::AlwaysInline)) 2094 goto CleanupAndExit; 2095 2096 // In case of a memmove, the call to memmove will be executed instead 2097 // of the loop, so we need to make sure that there is nothing else in 2098 // the loop than the load, store and instructions that these two depend 2099 // on. 2100 SmallVector<Instruction*,2> Insts; 2101 Insts.push_back(SI); 2102 Insts.push_back(LI); 2103 if (!coverLoop(CurLoop, Insts)) 2104 goto CleanupAndExit; 2105 2106 if (DisableMemmoveIdiom || !HasMemmove) 2107 goto CleanupAndExit; 2108 bool IsNested = CurLoop->getParentLoop() != nullptr; 2109 if (IsNested && OnlyNonNestedMemmove) 2110 goto CleanupAndExit; 2111 } 2112 2113 // For a memcpy, we have to make sure that the input array is not being 2114 // mutated by the loop. 
2115 LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(), 2116 Builder.getInt8PtrTy(LI->getPointerAddressSpace()), ExpPt); 2117 2118 SmallPtrSet<Instruction*, 2> Ignore2; 2119 Ignore2.insert(SI); 2120 if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, 2121 StoreSize, *AA, Ignore2)) 2122 goto CleanupAndExit; 2123 2124 // Check the stride. 2125 bool StridePos = getSCEVStride(LoadEv) >= 0; 2126 2127 // Currently, the volatile memcpy only emulates traversing memory forward. 2128 if (!StridePos && DestVolatile) 2129 goto CleanupAndExit; 2130 2131 bool RuntimeCheck = (Overlap || DestVolatile); 2132 2133 BasicBlock *ExitB; 2134 if (RuntimeCheck) { 2135 // The runtime check needs a single exit block. 2136 SmallVector<BasicBlock*, 8> ExitBlocks; 2137 CurLoop->getUniqueExitBlocks(ExitBlocks); 2138 if (ExitBlocks.size() != 1) 2139 goto CleanupAndExit; 2140 ExitB = ExitBlocks[0]; 2141 } 2142 2143 // The # stored bytes is (BECount+1)*Size. Expand the trip count out to 2144 // pointer size if it isn't already. 
2145 LLVMContext &Ctx = SI->getContext(); 2146 BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); 2147 DebugLoc DLoc = SI->getDebugLoc(); 2148 2149 const SCEV *NumBytesS = 2150 SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); 2151 if (StoreSize != 1) 2152 NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), 2153 SCEV::FlagNUW); 2154 Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt); 2155 if (Instruction *In = dyn_cast<Instruction>(NumBytes)) 2156 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2157 NumBytes = Simp; 2158 2159 CallInst *NewCall; 2160 2161 if (RuntimeCheck) { 2162 unsigned Threshold = RuntimeMemSizeThreshold; 2163 if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) { 2164 uint64_t C = CI->getZExtValue(); 2165 if (Threshold != 0 && C < Threshold) 2166 goto CleanupAndExit; 2167 if (C < CompileTimeMemSizeThreshold) 2168 goto CleanupAndExit; 2169 } 2170 2171 BasicBlock *Header = CurLoop->getHeader(); 2172 Function *Func = Header->getParent(); 2173 Loop *ParentL = LF->getLoopFor(Preheader); 2174 StringRef HeaderName = Header->getName(); 2175 2176 // Create a new (empty) preheader, and update the PHI nodes in the 2177 // header to use the new preheader. 2178 BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph", 2179 Func, Header); 2180 if (ParentL) 2181 ParentL->addBasicBlockToLoop(NewPreheader, *LF); 2182 IRBuilder<>(NewPreheader).CreateBr(Header); 2183 for (auto &In : *Header) { 2184 PHINode *PN = dyn_cast<PHINode>(&In); 2185 if (!PN) 2186 break; 2187 int bx = PN->getBasicBlockIndex(Preheader); 2188 if (bx >= 0) 2189 PN->setIncomingBlock(bx, NewPreheader); 2190 } 2191 DT->addNewBlock(NewPreheader, Preheader); 2192 DT->changeImmediateDominator(Header, NewPreheader); 2193 2194 // Check for safe conditions to execute memmove. 2195 // If stride is positive, copying things from higher to lower addresses 2196 // is equivalent to memmove. 
For negative stride, it's the other way 2197 // around. Copying forward in memory with positive stride may not be 2198 // same as memmove since we may be copying values that we just stored 2199 // in some previous iteration. 2200 Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy); 2201 Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy); 2202 Value *LowA = StridePos ? SA : LA; 2203 Value *HighA = StridePos ? LA : SA; 2204 Value *CmpA = Builder.CreateICmpULT(LowA, HighA); 2205 Value *Cond = CmpA; 2206 2207 // Check for distance between pointers. Since the case LowA < HighA 2208 // is checked for above, assume LowA >= HighA. 2209 Value *Dist = Builder.CreateSub(LowA, HighA); 2210 Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist); 2211 Value *CmpEither = Builder.CreateOr(Cond, CmpD); 2212 Cond = CmpEither; 2213 2214 if (Threshold != 0) { 2215 Type *Ty = NumBytes->getType(); 2216 Value *Thr = ConstantInt::get(Ty, Threshold); 2217 Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes); 2218 Value *CmpBoth = Builder.CreateAnd(Cond, CmpB); 2219 Cond = CmpBoth; 2220 } 2221 BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli", 2222 Func, NewPreheader); 2223 if (ParentL) 2224 ParentL->addBasicBlockToLoop(MemmoveB, *LF); 2225 Instruction *OldT = Preheader->getTerminator(); 2226 Builder.CreateCondBr(Cond, MemmoveB, NewPreheader); 2227 OldT->eraseFromParent(); 2228 Preheader->setName(Preheader->getName()+".old"); 2229 DT->addNewBlock(MemmoveB, Preheader); 2230 // Find the new immediate dominator of the exit block. 2231 BasicBlock *ExitD = Preheader; 2232 for (auto PI = pred_begin(ExitB), PE = pred_end(ExitB); PI != PE; ++PI) { 2233 BasicBlock *PB = *PI; 2234 ExitD = DT->findNearestCommonDominator(ExitD, PB); 2235 if (!ExitD) 2236 break; 2237 } 2238 // If the prior immediate dominator of ExitB was dominated by the 2239 // old preheader, then the old preheader becomes the new immediate 2240 // dominator. 
Otherwise don't change anything (because the newly 2241 // added blocks are dominated by the old preheader). 2242 if (ExitD && DT->dominates(Preheader, ExitD)) { 2243 DomTreeNode *BN = DT->getNode(ExitB); 2244 DomTreeNode *DN = DT->getNode(ExitD); 2245 BN->setIDom(DN); 2246 } 2247 2248 // Add a call to memmove to the conditional block. 2249 IRBuilder<> CondBuilder(MemmoveB); 2250 CondBuilder.CreateBr(ExitB); 2251 CondBuilder.SetInsertPoint(MemmoveB->getTerminator()); 2252 2253 if (DestVolatile) { 2254 Type *Int32Ty = Type::getInt32Ty(Ctx); 2255 Type *Int32PtrTy = Type::getInt32PtrTy(Ctx); 2256 Type *VoidTy = Type::getVoidTy(Ctx); 2257 Module *M = Func->getParent(); 2258 FunctionCallee Fn = M->getOrInsertFunction( 2259 HexagonVolatileMemcpyName, VoidTy, Int32PtrTy, Int32PtrTy, Int32Ty); 2260 2261 const SCEV *OneS = SE->getConstant(Int32Ty, 1); 2262 const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty); 2263 const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW); 2264 Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty, 2265 MemmoveB->getTerminator()); 2266 if (Instruction *In = dyn_cast<Instruction>(NumWords)) 2267 if (Value *Simp = SimplifyInstruction(In, {*DL, TLI, DT})) 2268 NumWords = Simp; 2269 2270 Value *Op0 = (StoreBasePtr->getType() == Int32PtrTy) 2271 ? StoreBasePtr 2272 : CondBuilder.CreateBitCast(StoreBasePtr, Int32PtrTy); 2273 Value *Op1 = (LoadBasePtr->getType() == Int32PtrTy) 2274 ? LoadBasePtr 2275 : CondBuilder.CreateBitCast(LoadBasePtr, Int32PtrTy); 2276 NewCall = CondBuilder.CreateCall(Fn, {Op0, Op1, NumWords}); 2277 } else { 2278 NewCall = CondBuilder.CreateMemMove( 2279 StoreBasePtr, SI->getAlign(), LoadBasePtr, LI->getAlign(), NumBytes); 2280 } 2281 } else { 2282 NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlign(), LoadBasePtr, 2283 LI->getAlign(), NumBytes); 2284 // Okay, the memcpy has been formed. Zap the original store and 2285 // anything that feeds into it. 
2286 RecursivelyDeleteTriviallyDeadInstructions(SI, TLI); 2287 } 2288 2289 NewCall->setDebugLoc(DLoc); 2290 2291 LLVM_DEBUG(dbgs() << " Formed " << (Overlap ? "memmove: " : "memcpy: ") 2292 << *NewCall << "\n" 2293 << " from load ptr=" << *LoadEv << " at: " << *LI << "\n" 2294 << " from store ptr=" << *StoreEv << " at: " << *SI 2295 << "\n"); 2296 2297 return true; 2298 } 2299 2300 // Check if the instructions in Insts, together with their dependencies 2301 // cover the loop in the sense that the loop could be safely eliminated once 2302 // the instructions in Insts are removed. 2303 bool HexagonLoopIdiomRecognize::coverLoop(Loop *L, 2304 SmallVectorImpl<Instruction*> &Insts) const { 2305 SmallSet<BasicBlock*,8> LoopBlocks; 2306 for (auto *B : L->blocks()) 2307 LoopBlocks.insert(B); 2308 2309 SetVector<Instruction*> Worklist(Insts.begin(), Insts.end()); 2310 2311 // Collect all instructions from the loop that the instructions in Insts 2312 // depend on (plus their dependencies, etc.). These instructions will 2313 // constitute the expression trees that feed those in Insts, but the trees 2314 // will be limited only to instructions contained in the loop. 2315 for (unsigned i = 0; i < Worklist.size(); ++i) { 2316 Instruction *In = Worklist[i]; 2317 for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) { 2318 Instruction *OpI = dyn_cast<Instruction>(I); 2319 if (!OpI) 2320 continue; 2321 BasicBlock *PB = OpI->getParent(); 2322 if (!LoopBlocks.count(PB)) 2323 continue; 2324 Worklist.insert(OpI); 2325 } 2326 } 2327 2328 // Scan all instructions in the loop, if any of them have a user outside 2329 // of the loop, or outside of the expressions collected above, then either 2330 // the loop has a side-effect visible outside of it, or there are 2331 // instructions in it that are not involved in the original set Insts. 
2332 for (auto *B : L->blocks()) { 2333 for (auto &In : *B) { 2334 if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In)) 2335 continue; 2336 if (!Worklist.count(&In) && In.mayHaveSideEffects()) 2337 return false; 2338 for (auto K : In.users()) { 2339 Instruction *UseI = dyn_cast<Instruction>(K); 2340 if (!UseI) 2341 continue; 2342 BasicBlock *UseB = UseI->getParent(); 2343 if (LF->getLoopFor(UseB) != L) 2344 return false; 2345 } 2346 } 2347 } 2348 2349 return true; 2350 } 2351 2352 /// runOnLoopBlock - Process the specified block, which lives in a counted loop 2353 /// with the specified backedge count. This block is known to be in the current 2354 /// loop and not in any subloops. 2355 bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, 2356 const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) { 2357 // We can only promote stores in this block if they are unconditionally 2358 // executed in the loop. For a block to be unconditionally executed, it has 2359 // to dominate all the exit blocks of the loop. Verify this now. 2360 auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool { 2361 return DT->dominates(BB, EB); 2362 }; 2363 if (!all_of(ExitBlocks, DominatedByBB)) 2364 return false; 2365 2366 bool MadeChange = false; 2367 // Look for store instructions, which may be optimized to memset/memcpy. 2368 SmallVector<StoreInst*,8> Stores; 2369 collectStores(CurLoop, BB, Stores); 2370 2371 // Optimize the store into a memcpy, if it feeds an similarly strided load. 
2372 for (auto &SI : Stores) 2373 MadeChange |= processCopyingStore(CurLoop, SI, BECount); 2374 2375 return MadeChange; 2376 } 2377 2378 bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) { 2379 PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE); 2380 if (PMR.recognize()) 2381 return true; 2382 2383 if (!HasMemcpy && !HasMemmove) 2384 return false; 2385 2386 const SCEV *BECount = SE->getBackedgeTakenCount(L); 2387 assert(!isa<SCEVCouldNotCompute>(BECount) && 2388 "runOnCountableLoop() called on a loop without a predictable" 2389 "backedge-taken count"); 2390 2391 SmallVector<BasicBlock *, 8> ExitBlocks; 2392 L->getUniqueExitBlocks(ExitBlocks); 2393 2394 bool Changed = false; 2395 2396 // Scan all the blocks in the loop that are not in subloops. 2397 for (auto *BB : L->getBlocks()) { 2398 // Ignore blocks in subloops. 2399 if (LF->getLoopFor(BB) != L) 2400 continue; 2401 Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks); 2402 } 2403 2404 return Changed; 2405 } 2406 2407 bool HexagonLoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) { 2408 const Module &M = *L->getHeader()->getParent()->getParent(); 2409 if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon) 2410 return false; 2411 2412 if (skipLoop(L)) 2413 return false; 2414 2415 // If the loop could not be converted to canonical form, it must have an 2416 // indirectbr in it, just give up. 2417 if (!L->getLoopPreheader()) 2418 return false; 2419 2420 // Disable loop idiom recognition if the function's name is a common idiom. 
2421 StringRef Name = L->getHeader()->getParent()->getName(); 2422 if (Name == "memset" || Name == "memcpy" || Name == "memmove") 2423 return false; 2424 2425 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 2426 DL = &L->getHeader()->getModule()->getDataLayout(); 2427 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 2428 LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 2429 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( 2430 *L->getHeader()->getParent()); 2431 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 2432 2433 HasMemcpy = TLI->has(LibFunc_memcpy); 2434 HasMemmove = TLI->has(LibFunc_memmove); 2435 2436 if (SE->hasLoopInvariantBackedgeTakenCount(L)) 2437 return runOnCountableLoop(L); 2438 return false; 2439 } 2440 2441 Pass *llvm::createHexagonLoopIdiomPass() { 2442 return new HexagonLoopIdiomRecognize(); 2443 } 2444