//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "AggressiveInstCombineInternal.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/Statistic.h" 30 #include "llvm/Analysis/ConstantFolding.h" 31 #include "llvm/Analysis/TargetLibraryInfo.h" 32 #include "llvm/IR/DataLayout.h" 33 #include "llvm/IR/Dominators.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/Instruction.h" 36 #include "llvm/Support/KnownBits.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "aggressive-instcombine" 41 42 STATISTIC( 43 NumDAGsReduced, 44 "Number of truncations eliminated by reducing bit width of expression DAG"); 45 STATISTIC(NumInstrsReduced, 46 "Number of instructions whose bit width was reduced"); 47 48 /// Given an instruction and a container, it fills all the relevant operands of 49 /// that instruction, with respect to the Trunc expression dag optimizaton. 50 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { 51 unsigned Opc = I->getOpcode(); 52 switch (Opc) { 53 case Instruction::Trunc: 54 case Instruction::ZExt: 55 case Instruction::SExt: 56 // These CastInst are considered leaves of the evaluated expression, thus, 57 // their operands are not relevent. 
58 break; 59 case Instruction::Add: 60 case Instruction::Sub: 61 case Instruction::Mul: 62 case Instruction::And: 63 case Instruction::Or: 64 case Instruction::Xor: 65 case Instruction::Shl: 66 case Instruction::LShr: 67 case Instruction::AShr: 68 case Instruction::UDiv: 69 case Instruction::URem: 70 case Instruction::InsertElement: 71 Ops.push_back(I->getOperand(0)); 72 Ops.push_back(I->getOperand(1)); 73 break; 74 case Instruction::ExtractElement: 75 Ops.push_back(I->getOperand(0)); 76 break; 77 case Instruction::Select: 78 Ops.push_back(I->getOperand(1)); 79 Ops.push_back(I->getOperand(2)); 80 break; 81 default: 82 llvm_unreachable("Unreachable!"); 83 } 84 } 85 86 bool TruncInstCombine::buildTruncExpressionDag() { 87 SmallVector<Value *, 8> Worklist; 88 SmallVector<Instruction *, 8> Stack; 89 // Clear old expression dag. 90 InstInfoMap.clear(); 91 92 Worklist.push_back(CurrentTruncInst->getOperand(0)); 93 94 while (!Worklist.empty()) { 95 Value *Curr = Worklist.back(); 96 97 if (isa<Constant>(Curr)) { 98 Worklist.pop_back(); 99 continue; 100 } 101 102 auto *I = dyn_cast<Instruction>(Curr); 103 if (!I) 104 return false; 105 106 if (!Stack.empty() && Stack.back() == I) { 107 // Already handled all instruction operands, can remove it from both the 108 // Worklist and the Stack, and add it to the instruction info map. 109 Worklist.pop_back(); 110 Stack.pop_back(); 111 // Insert I to the Info map. 112 InstInfoMap.insert(std::make_pair(I, Info())); 113 continue; 114 } 115 116 if (InstInfoMap.count(I)) { 117 Worklist.pop_back(); 118 continue; 119 } 120 121 // Add the instruction to the stack before start handling its operands. 
122 Stack.push_back(I); 123 124 unsigned Opc = I->getOpcode(); 125 switch (Opc) { 126 case Instruction::Trunc: 127 case Instruction::ZExt: 128 case Instruction::SExt: 129 // trunc(trunc(x)) -> trunc(x) 130 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest 131 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new 132 // dest 133 break; 134 case Instruction::Add: 135 case Instruction::Sub: 136 case Instruction::Mul: 137 case Instruction::And: 138 case Instruction::Or: 139 case Instruction::Xor: 140 case Instruction::Shl: 141 case Instruction::LShr: 142 case Instruction::AShr: 143 case Instruction::UDiv: 144 case Instruction::URem: 145 case Instruction::InsertElement: 146 case Instruction::ExtractElement: 147 case Instruction::Select: { 148 SmallVector<Value *, 2> Operands; 149 getRelevantOperands(I, Operands); 150 append_range(Worklist, Operands); 151 break; 152 } 153 default: 154 // TODO: Can handle more cases here: 155 // 1. shufflevector 156 // 2. sdiv, srem 157 // 3. phi node(and loop handling) 158 // ... 159 return false; 160 } 161 } 162 return true; 163 } 164 165 unsigned TruncInstCombine::getMinBitWidth() { 166 SmallVector<Value *, 8> Worklist; 167 SmallVector<Instruction *, 8> Stack; 168 169 Value *Src = CurrentTruncInst->getOperand(0); 170 Type *DstTy = CurrentTruncInst->getType(); 171 unsigned TruncBitWidth = DstTy->getScalarSizeInBits(); 172 unsigned OrigBitWidth = 173 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 174 175 if (isa<Constant>(Src)) 176 return TruncBitWidth; 177 178 Worklist.push_back(Src); 179 InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth; 180 181 while (!Worklist.empty()) { 182 Value *Curr = Worklist.back(); 183 184 if (isa<Constant>(Curr)) { 185 Worklist.pop_back(); 186 continue; 187 } 188 189 // Otherwise, it must be an instruction. 
190 auto *I = cast<Instruction>(Curr); 191 192 auto &Info = InstInfoMap[I]; 193 194 SmallVector<Value *, 2> Operands; 195 getRelevantOperands(I, Operands); 196 197 if (!Stack.empty() && Stack.back() == I) { 198 // Already handled all instruction operands, can remove it from both, the 199 // Worklist and the Stack, and update MinBitWidth. 200 Worklist.pop_back(); 201 Stack.pop_back(); 202 for (auto *Operand : Operands) 203 if (auto *IOp = dyn_cast<Instruction>(Operand)) 204 Info.MinBitWidth = 205 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); 206 continue; 207 } 208 209 // Add the instruction to the stack before start handling its operands. 210 Stack.push_back(I); 211 unsigned ValidBitWidth = Info.ValidBitWidth; 212 213 // Update minimum bit-width before handling its operands. This is required 214 // when the instruction is part of a loop. 215 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth); 216 217 for (auto *Operand : Operands) 218 if (auto *IOp = dyn_cast<Instruction>(Operand)) { 219 // If we already calculated the minimum bit-width for this valid 220 // bit-width, or for a smaller valid bit-width, then just keep the 221 // answer we already calculated. 222 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; 223 if (IOpBitwidth >= ValidBitWidth) 224 continue; 225 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; 226 Worklist.push_back(IOp); 227 } 228 } 229 unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth; 230 assert(MinBitWidth >= TruncBitWidth); 231 232 if (MinBitWidth > TruncBitWidth) { 233 // In this case reducing expression with vector type might generate a new 234 // vector type, which is not preferable as it might result in generating 235 // sub-optimal code. 236 if (DstTy->isVectorTy()) 237 return OrigBitWidth; 238 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth). 
239 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth); 240 // Update minimum bit-width with the new destination type bit-width if 241 // succeeded to find such, otherwise, with original bit-width. 242 MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth; 243 } else { // MinBitWidth == TruncBitWidth 244 // In this case the expression can be evaluated with the trunc instruction 245 // destination type, and trunc instruction can be omitted. However, we 246 // should not perform the evaluation if the original type is a legal scalar 247 // type and the target type is illegal. 248 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth); 249 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth); 250 if (!DstTy->isVectorTy() && FromLegal && !ToLegal) 251 return OrigBitWidth; 252 } 253 return MinBitWidth; 254 } 255 256 Type *TruncInstCombine::getBestTruncatedType() { 257 if (!buildTruncExpressionDag()) 258 return nullptr; 259 260 // We don't want to duplicate instructions, which isn't profitable. Thus, we 261 // can't shrink something that has multiple users, unless all users are 262 // post-dominated by the trunc instruction, i.e., were visited during the 263 // expression evaluation. 264 unsigned DesiredBitWidth = 0; 265 for (auto Itr : InstInfoMap) { 266 Instruction *I = Itr.first; 267 if (I->hasOneUse()) 268 continue; 269 bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I)); 270 for (auto *U : I->users()) 271 if (auto *UI = dyn_cast<Instruction>(U)) 272 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) { 273 if (!IsExtInst) 274 return nullptr; 275 // If this is an extension from the dest type, we can eliminate it, 276 // even if it has multiple users. Thus, update the DesiredBitWidth and 277 // validate all extension instructions agrees on same DesiredBitWidth. 
278 unsigned ExtInstBitWidth = 279 I->getOperand(0)->getType()->getScalarSizeInBits(); 280 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth) 281 return nullptr; 282 DesiredBitWidth = ExtInstBitWidth; 283 } 284 } 285 286 unsigned OrigBitWidth = 287 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 288 289 // Initialize MinBitWidth for shift instructions with the minimum number 290 // that is greater than shift amount (i.e. shift amount + 1). 291 // For `lshr` adjust MinBitWidth so that all potentially truncated 292 // bits of the value-to-be-shifted are zeros. 293 // For `ashr` adjust MinBitWidth so that all potentially truncated 294 // bits of the value-to-be-shifted are sign bits (all zeros or ones) 295 // and even one (first) untruncated bit is sign bit. 296 // Exit early if MinBitWidth is not less than original bitwidth. 297 for (auto &Itr : InstInfoMap) { 298 Instruction *I = Itr.first; 299 if (I->isShift()) { 300 KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); 301 unsigned MinBitWidth = KnownRHS.getMaxValue() 302 .uadd_sat(APInt(OrigBitWidth, 1)) 303 .getLimitedValue(OrigBitWidth); 304 if (MinBitWidth == OrigBitWidth) 305 return nullptr; 306 if (I->getOpcode() == Instruction::LShr) { 307 KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); 308 MinBitWidth = 309 std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); 310 } 311 if (I->getOpcode() == Instruction::AShr) { 312 unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); 313 MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); 314 } 315 if (MinBitWidth >= OrigBitWidth) 316 return nullptr; 317 Itr.second.MinBitWidth = MinBitWidth; 318 } 319 if (I->getOpcode() == Instruction::UDiv || 320 I->getOpcode() == Instruction::URem) { 321 unsigned MinBitWidth = 0; 322 for (const auto &Op : I->operands()) { 323 KnownBits Known = computeKnownBits(Op); 324 MinBitWidth = 325 std::max(Known.getMaxValue().getActiveBits(), MinBitWidth); 326 if (MinBitWidth 
>= OrigBitWidth) 327 return nullptr; 328 } 329 Itr.second.MinBitWidth = MinBitWidth; 330 } 331 } 332 333 // Calculate minimum allowed bit-width allowed for shrinking the currently 334 // visited truncate's operand. 335 unsigned MinBitWidth = getMinBitWidth(); 336 337 // Check that we can shrink to smaller bit-width than original one and that 338 // it is similar to the DesiredBitWidth is such exists. 339 if (MinBitWidth >= OrigBitWidth || 340 (DesiredBitWidth && DesiredBitWidth != MinBitWidth)) 341 return nullptr; 342 343 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth); 344 } 345 346 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type 347 /// for \p V, according to its type, if it vector type, return the vector 348 /// version of \p Ty, otherwise return \p Ty. 349 static Type *getReducedType(Value *V, Type *Ty) { 350 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type"); 351 if (auto *VTy = dyn_cast<VectorType>(V->getType())) 352 return VectorType::get(Ty, VTy->getElementCount()); 353 return Ty; 354 } 355 356 Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { 357 Type *Ty = getReducedType(V, SclTy); 358 if (auto *C = dyn_cast<Constant>(V)) { 359 C = ConstantExpr::getIntegerCast(C, Ty, false); 360 // If we got a constantexpr back, try to simplify it with DL info. 
361 return ConstantFoldConstant(C, DL, &TLI); 362 } 363 364 auto *I = cast<Instruction>(V); 365 Info Entry = InstInfoMap.lookup(I); 366 assert(Entry.NewValue); 367 return Entry.NewValue; 368 } 369 370 void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { 371 NumInstrsReduced += InstInfoMap.size(); 372 for (auto &Itr : InstInfoMap) { // Forward 373 Instruction *I = Itr.first; 374 TruncInstCombine::Info &NodeInfo = Itr.second; 375 376 assert(!NodeInfo.NewValue && "Instruction has been evaluated"); 377 378 IRBuilder<> Builder(I); 379 Value *Res = nullptr; 380 unsigned Opc = I->getOpcode(); 381 switch (Opc) { 382 case Instruction::Trunc: 383 case Instruction::ZExt: 384 case Instruction::SExt: { 385 Type *Ty = getReducedType(I, SclTy); 386 // If the source type of the cast is the type we're trying for then we can 387 // just return the source. There's no need to insert it because it is not 388 // new. 389 if (I->getOperand(0)->getType() == Ty) { 390 assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); 391 NodeInfo.NewValue = I->getOperand(0); 392 continue; 393 } 394 // Otherwise, must be the same type of cast, so just reinsert a new one. 395 // This also handles the case of zext(trunc(x)) -> zext(x). 396 Res = Builder.CreateIntCast(I->getOperand(0), Ty, 397 Opc == Instruction::SExt); 398 399 // Update Worklist entries with new value if needed. 400 // There are three possible changes to the Worklist: 401 // 1. Update Old-TruncInst -> New-TruncInst. 402 // 2. Remove Old-TruncInst (if New node is not TruncInst). 403 // 3. Add New-TruncInst (if Old node was not TruncInst). 
404 auto *Entry = find(Worklist, I); 405 if (Entry != Worklist.end()) { 406 if (auto *NewCI = dyn_cast<TruncInst>(Res)) 407 *Entry = NewCI; 408 else 409 Worklist.erase(Entry); 410 } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) 411 Worklist.push_back(NewCI); 412 break; 413 } 414 case Instruction::Add: 415 case Instruction::Sub: 416 case Instruction::Mul: 417 case Instruction::And: 418 case Instruction::Or: 419 case Instruction::Xor: 420 case Instruction::Shl: 421 case Instruction::LShr: 422 case Instruction::AShr: 423 case Instruction::UDiv: 424 case Instruction::URem: { 425 Value *LHS = getReducedOperand(I->getOperand(0), SclTy); 426 Value *RHS = getReducedOperand(I->getOperand(1), SclTy); 427 Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); 428 // Preserve `exact` flag since truncation doesn't change exactness 429 if (auto *PEO = dyn_cast<PossiblyExactOperator>(I)) 430 if (auto *ResI = dyn_cast<Instruction>(Res)) 431 ResI->setIsExact(PEO->isExact()); 432 break; 433 } 434 case Instruction::ExtractElement: { 435 Value *Vec = getReducedOperand(I->getOperand(0), SclTy); 436 Value *Idx = I->getOperand(1); 437 Res = Builder.CreateExtractElement(Vec, Idx); 438 break; 439 } 440 case Instruction::InsertElement: { 441 Value *Vec = getReducedOperand(I->getOperand(0), SclTy); 442 Value *NewElt = getReducedOperand(I->getOperand(1), SclTy); 443 Value *Idx = I->getOperand(2); 444 Res = Builder.CreateInsertElement(Vec, NewElt, Idx); 445 break; 446 } 447 case Instruction::Select: { 448 Value *Op0 = I->getOperand(0); 449 Value *LHS = getReducedOperand(I->getOperand(1), SclTy); 450 Value *RHS = getReducedOperand(I->getOperand(2), SclTy); 451 Res = Builder.CreateSelect(Op0, LHS, RHS); 452 break; 453 } 454 default: 455 llvm_unreachable("Unhandled instruction"); 456 } 457 458 NodeInfo.NewValue = Res; 459 if (auto *ResI = dyn_cast<Instruction>(Res)) 460 ResI->takeName(I); 461 } 462 463 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); 464 Type 
*DstTy = CurrentTruncInst->getType(); 465 if (Res->getType() != DstTy) { 466 IRBuilder<> Builder(CurrentTruncInst); 467 Res = Builder.CreateIntCast(Res, DstTy, false); 468 if (auto *ResI = dyn_cast<Instruction>(Res)) 469 ResI->takeName(CurrentTruncInst); 470 } 471 CurrentTruncInst->replaceAllUsesWith(Res); 472 473 // Erase old expression dag, which was replaced by the reduced expression dag. 474 // We iterate backward, which means we visit the instruction before we visit 475 // any of its operands, this way, when we get to the operand, we already 476 // removed the instructions (from the expression dag) that uses it. 477 CurrentTruncInst->eraseFromParent(); 478 for (auto &I : llvm::reverse(InstInfoMap)) { 479 // We still need to check that the instruction has no users before we erase 480 // it, because {SExt, ZExt}Inst Instruction might have other users that was 481 // not reduced, in such case, we need to keep that instruction. 482 if (I.first->use_empty()) 483 I.first->eraseFromParent(); 484 } 485 } 486 487 bool TruncInstCombine::run(Function &F) { 488 bool MadeIRChange = false; 489 490 // Collect all TruncInst in the function into the Worklist for evaluating. 491 for (auto &BB : F) { 492 // Ignore unreachable basic block. 493 if (!DT.isReachableFromEntry(&BB)) 494 continue; 495 for (auto &I : BB) 496 if (auto *CI = dyn_cast<TruncInst>(&I)) 497 Worklist.push_back(CI); 498 } 499 500 // Process all TruncInst in the Worklist, for each instruction: 501 // 1. Check if it dominates an eligible expression dag to be reduced. 502 // 2. Create a reduced expression dag and replace the old one with it. 
503 while (!Worklist.empty()) { 504 CurrentTruncInst = Worklist.pop_back_val(); 505 506 if (Type *NewDstSclTy = getBestTruncatedType()) { 507 LLVM_DEBUG( 508 dbgs() << "ICE: TruncInstCombine reducing type of expression dag " 509 "dominated by: " 510 << CurrentTruncInst << '\n'); 511 ReduceExpressionDag(NewDstSclTy); 512 ++NumDAGsReduced; 513 MadeIRChange = true; 514 } 515 } 516 517 return MadeIRChange; 518 } 519