//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by a TruncInst
// and, for each eligible dag, creates a reduced bit-width expression, replaces
// the old expression with the new one and removes the old expression.
// An eligible expression dag is one that:
//   1. Contains only supported instructions.
//   2. Has only supported leaves: ZExtInst, SExtInst, TruncInst and Constant
//      values.
//   3. Can be evaluated in a type with a reduced legal bit-width.
//   4. Has no instruction with users outside the dag.
//      The only exception is {ZExt, SExt}Inst whose operand type equals the
//      new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// a smaller bit-width is preferable, especially for vectorization, where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will never increase it.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "AggressiveInstCombineInternal.h" 28 #include "llvm/ADT/MapVector.h" 29 #include "llvm/ADT/STLExtras.h" 30 #include "llvm/Analysis/ConstantFolding.h" 31 #include "llvm/Analysis/TargetLibraryInfo.h" 32 #include "llvm/IR/DataLayout.h" 33 #include "llvm/IR/Dominators.h" 34 #include "llvm/IR/IRBuilder.h" 35 using namespace llvm; 36 37 #define DEBUG_TYPE "aggressive-instcombine" 38 39 /// Given an instruction and a container, it fills all the relevant operands of 40 /// that instruction, with respect to the Trunc expression dag optimizaton. 41 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { 42 unsigned Opc = I->getOpcode(); 43 switch (Opc) { 44 case Instruction::Trunc: 45 case Instruction::ZExt: 46 case Instruction::SExt: 47 // These CastInst are considered leaves of the evaluated expression, thus, 48 // their operands are not relevent. 49 break; 50 case Instruction::Add: 51 case Instruction::Sub: 52 case Instruction::Mul: 53 case Instruction::And: 54 case Instruction::Or: 55 case Instruction::Xor: 56 Ops.push_back(I->getOperand(0)); 57 Ops.push_back(I->getOperand(1)); 58 break; 59 default: 60 llvm_unreachable("Unreachable!"); 61 } 62 } 63 64 bool TruncInstCombine::buildTruncExpressionDag() { 65 SmallVector<Value *, 8> Worklist; 66 SmallVector<Instruction *, 8> Stack; 67 // Clear old expression dag. 68 InstInfoMap.clear(); 69 70 Worklist.push_back(CurrentTruncInst->getOperand(0)); 71 72 while (!Worklist.empty()) { 73 Value *Curr = Worklist.back(); 74 75 if (isa<Constant>(Curr)) { 76 Worklist.pop_back(); 77 continue; 78 } 79 80 auto *I = dyn_cast<Instruction>(Curr); 81 if (!I) 82 return false; 83 84 if (!Stack.empty() && Stack.back() == I) { 85 // Already handled all instruction operands, can remove it from both the 86 // Worklist and the Stack, and add it to the instruction info map. 
87 Worklist.pop_back(); 88 Stack.pop_back(); 89 // Insert I to the Info map. 90 InstInfoMap.insert(std::make_pair(I, Info())); 91 continue; 92 } 93 94 if (InstInfoMap.count(I)) { 95 Worklist.pop_back(); 96 continue; 97 } 98 99 // Add the instruction to the stack before start handling its operands. 100 Stack.push_back(I); 101 102 unsigned Opc = I->getOpcode(); 103 switch (Opc) { 104 case Instruction::Trunc: 105 case Instruction::ZExt: 106 case Instruction::SExt: 107 // trunc(trunc(x)) -> trunc(x) 108 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest 109 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new 110 // dest 111 break; 112 case Instruction::Add: 113 case Instruction::Sub: 114 case Instruction::Mul: 115 case Instruction::And: 116 case Instruction::Or: 117 case Instruction::Xor: { 118 SmallVector<Value *, 2> Operands; 119 getRelevantOperands(I, Operands); 120 for (Value *Operand : Operands) 121 Worklist.push_back(Operand); 122 break; 123 } 124 default: 125 // TODO: Can handle more cases here: 126 // 1. select, shufflevector, extractelement, insertelement 127 // 2. udiv, urem 128 // 3. shl, lshr, ashr 129 // 4. phi node(and loop handling) 130 // ... 
131 return false; 132 } 133 } 134 return true; 135 } 136 137 unsigned TruncInstCombine::getMinBitWidth() { 138 SmallVector<Value *, 8> Worklist; 139 SmallVector<Instruction *, 8> Stack; 140 141 Value *Src = CurrentTruncInst->getOperand(0); 142 Type *DstTy = CurrentTruncInst->getType(); 143 unsigned TruncBitWidth = DstTy->getScalarSizeInBits(); 144 unsigned OrigBitWidth = 145 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 146 147 if (isa<Constant>(Src)) 148 return TruncBitWidth; 149 150 Worklist.push_back(Src); 151 InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth; 152 153 while (!Worklist.empty()) { 154 Value *Curr = Worklist.back(); 155 156 if (isa<Constant>(Curr)) { 157 Worklist.pop_back(); 158 continue; 159 } 160 161 // Otherwise, it must be an instruction. 162 auto *I = cast<Instruction>(Curr); 163 164 auto &Info = InstInfoMap[I]; 165 166 SmallVector<Value *, 2> Operands; 167 getRelevantOperands(I, Operands); 168 169 if (!Stack.empty() && Stack.back() == I) { 170 // Already handled all instruction operands, can remove it from both, the 171 // Worklist and the Stack, and update MinBitWidth. 172 Worklist.pop_back(); 173 Stack.pop_back(); 174 for (auto *Operand : Operands) 175 if (auto *IOp = dyn_cast<Instruction>(Operand)) 176 Info.MinBitWidth = 177 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); 178 continue; 179 } 180 181 // Add the instruction to the stack before start handling its operands. 182 Stack.push_back(I); 183 unsigned ValidBitWidth = Info.ValidBitWidth; 184 185 // Update minimum bit-width before handling its operands. This is required 186 // when the instruction is part of a loop. 
187 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth); 188 189 for (auto *Operand : Operands) 190 if (auto *IOp = dyn_cast<Instruction>(Operand)) { 191 // If we already calculated the minimum bit-width for this valid 192 // bit-width, or for a smaller valid bit-width, then just keep the 193 // answer we already calculated. 194 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; 195 if (IOpBitwidth >= ValidBitWidth) 196 continue; 197 InstInfoMap[IOp].ValidBitWidth = std::max(ValidBitWidth, IOpBitwidth); 198 Worklist.push_back(IOp); 199 } 200 } 201 unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth; 202 assert(MinBitWidth >= TruncBitWidth); 203 204 if (MinBitWidth > TruncBitWidth) { 205 // In this case reducing expression with vector type might generate a new 206 // vector type, which is not preferable as it might result in generating 207 // sub-optimal code. 208 if (DstTy->isVectorTy()) 209 return OrigBitWidth; 210 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth). 211 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth); 212 // Update minimum bit-width with the new destination type bit-width if 213 // succeeded to find such, otherwise, with original bit-width. 214 MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth; 215 } else { // MinBitWidth == TruncBitWidth 216 // In this case the expression can be evaluated with the trunc instruction 217 // destination type, and trunc instruction can be omitted. However, we 218 // should not perform the evaluation if the original type is a legal scalar 219 // type and the target type is illegal. 
220 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth); 221 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth); 222 if (!DstTy->isVectorTy() && FromLegal && !ToLegal) 223 return OrigBitWidth; 224 } 225 return MinBitWidth; 226 } 227 228 Type *TruncInstCombine::getBestTruncatedType() { 229 if (!buildTruncExpressionDag()) 230 return nullptr; 231 232 // We don't want to duplicate instructions, which isn't profitable. Thus, we 233 // can't shrink something that has multiple users, unless all users are 234 // post-dominated by the trunc instruction, i.e., were visited during the 235 // expression evaluation. 236 unsigned DesiredBitWidth = 0; 237 for (auto Itr : InstInfoMap) { 238 Instruction *I = Itr.first; 239 if (I->hasOneUse()) 240 continue; 241 bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I)); 242 for (auto *U : I->users()) 243 if (auto *UI = dyn_cast<Instruction>(U)) 244 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) { 245 if (!IsExtInst) 246 return nullptr; 247 // If this is an extension from the dest type, we can eliminate it, 248 // even if it has multiple users. Thus, update the DesiredBitWidth and 249 // validate all extension instructions agrees on same DesiredBitWidth. 250 unsigned ExtInstBitWidth = 251 I->getOperand(0)->getType()->getScalarSizeInBits(); 252 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth) 253 return nullptr; 254 DesiredBitWidth = ExtInstBitWidth; 255 } 256 } 257 258 unsigned OrigBitWidth = 259 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 260 261 // Calculate minimum allowed bit-width allowed for shrinking the currently 262 // visited truncate's operand. 263 unsigned MinBitWidth = getMinBitWidth(); 264 265 // Check that we can shrink to smaller bit-width than original one and that 266 // it is similar to the DesiredBitWidth is such exists. 
267 if (MinBitWidth >= OrigBitWidth || 268 (DesiredBitWidth && DesiredBitWidth != MinBitWidth)) 269 return nullptr; 270 271 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth); 272 } 273 274 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type 275 /// for \p V, according to its type, if it vector type, return the vector 276 /// version of \p Ty, otherwise return \p Ty. 277 static Type *getReducedType(Value *V, Type *Ty) { 278 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type"); 279 if (auto *VTy = dyn_cast<VectorType>(V->getType())) 280 return VectorType::get(Ty, VTy->getNumElements()); 281 return Ty; 282 } 283 284 Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { 285 Type *Ty = getReducedType(V, SclTy); 286 if (auto *C = dyn_cast<Constant>(V)) { 287 C = ConstantExpr::getIntegerCast(C, Ty, false); 288 // If we got a constantexpr back, try to simplify it with DL info. 289 if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) 290 C = FoldedC; 291 return C; 292 } 293 294 auto *I = cast<Instruction>(V); 295 Info Entry = InstInfoMap.lookup(I); 296 assert(Entry.NewValue); 297 return Entry.NewValue; 298 } 299 300 void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { 301 for (auto &Itr : InstInfoMap) { // Forward 302 Instruction *I = Itr.first; 303 TruncInstCombine::Info &NodeInfo = Itr.second; 304 305 assert(!NodeInfo.NewValue && "Instruction has been evaluated"); 306 307 IRBuilder<> Builder(I); 308 Value *Res = nullptr; 309 unsigned Opc = I->getOpcode(); 310 switch (Opc) { 311 case Instruction::Trunc: 312 case Instruction::ZExt: 313 case Instruction::SExt: { 314 Type *Ty = getReducedType(I, SclTy); 315 // If the source type of the cast is the type we're trying for then we can 316 // just return the source. There's no need to insert it because it is not 317 // new. 
318 if (I->getOperand(0)->getType() == Ty) { 319 assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); 320 NodeInfo.NewValue = I->getOperand(0); 321 continue; 322 } 323 // Otherwise, must be the same type of cast, so just reinsert a new one. 324 // This also handles the case of zext(trunc(x)) -> zext(x). 325 Res = Builder.CreateIntCast(I->getOperand(0), Ty, 326 Opc == Instruction::SExt); 327 328 // Update Worklist entries with new value if needed. 329 // There are three possible changes to the Worklist: 330 // 1. Update Old-TruncInst -> New-TruncInst. 331 // 2. Remove Old-TruncInst (if New node is not TruncInst). 332 // 3. Add New-TruncInst (if Old node was not TruncInst). 333 auto Entry = find(Worklist, I); 334 if (Entry != Worklist.end()) { 335 if (auto *NewCI = dyn_cast<TruncInst>(Res)) 336 *Entry = NewCI; 337 else 338 Worklist.erase(Entry); 339 } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) 340 Worklist.push_back(NewCI); 341 break; 342 } 343 case Instruction::Add: 344 case Instruction::Sub: 345 case Instruction::Mul: 346 case Instruction::And: 347 case Instruction::Or: 348 case Instruction::Xor: { 349 Value *LHS = getReducedOperand(I->getOperand(0), SclTy); 350 Value *RHS = getReducedOperand(I->getOperand(1), SclTy); 351 Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); 352 break; 353 } 354 default: 355 llvm_unreachable("Unhandled instruction"); 356 } 357 358 NodeInfo.NewValue = Res; 359 if (auto *ResI = dyn_cast<Instruction>(Res)) 360 ResI->takeName(I); 361 } 362 363 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); 364 Type *DstTy = CurrentTruncInst->getType(); 365 if (Res->getType() != DstTy) { 366 IRBuilder<> Builder(CurrentTruncInst); 367 Res = Builder.CreateIntCast(Res, DstTy, false); 368 if (auto *ResI = dyn_cast<Instruction>(Res)) 369 ResI->takeName(CurrentTruncInst); 370 } 371 CurrentTruncInst->replaceAllUsesWith(Res); 372 373 // Erase old expression dag, which was replaced by the reduced expression 
dag. 374 // We iterate backward, which means we visit the instruction before we visit 375 // any of its operands, this way, when we get to the operand, we already 376 // removed the instructions (from the expression dag) that uses it. 377 CurrentTruncInst->eraseFromParent(); 378 for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) { 379 // We still need to check that the instruction has no users before we erase 380 // it, because {SExt, ZExt}Inst Instruction might have other users that was 381 // not reduced, in such case, we need to keep that instruction. 382 if (I->first->use_empty()) 383 I->first->eraseFromParent(); 384 } 385 } 386 387 bool TruncInstCombine::run(Function &F) { 388 bool MadeIRChange = false; 389 390 // Collect all TruncInst in the function into the Worklist for evaluating. 391 for (auto &BB : F) { 392 // Ignore unreachable basic block. 393 if (!DT.isReachableFromEntry(&BB)) 394 continue; 395 for (auto &I : BB) 396 if (auto *CI = dyn_cast<TruncInst>(&I)) 397 Worklist.push_back(CI); 398 } 399 400 // Process all TruncInst in the Worklist, for each instruction: 401 // 1. Check if it dominates an eligible expression dag to be reduced. 402 // 2. Create a reduced expression dag and replace the old one with it. 403 while (!Worklist.empty()) { 404 CurrentTruncInst = Worklist.pop_back_val(); 405 406 if (Type *NewDstSclTy = getBestTruncatedType()) { 407 LLVM_DEBUG( 408 dbgs() << "ICE: TruncInstCombine reducing type of expression dag " 409 "dominated by: " 410 << CurrentTruncInst << '\n'); 411 ReduceExpressionDag(NewDstSclTy); 412 MadeIRChange = true; 413 } 414 } 415 416 return MadeIRChange; 417 } 418