1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// 2 // instrinsics 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This pass replaces masked memory intrinsics - when unsupported by the target 11 // - with a chain of basic blocks, that deal with the elements one-by-one if the 12 // appropriate mask bit is set. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Analysis/DomTreeUpdater.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/IR/BasicBlock.h" 21 #include "llvm/IR/Constant.h" 22 #include "llvm/IR/Constants.h" 23 #include "llvm/IR/DerivedTypes.h" 24 #include "llvm/IR/Dominators.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/InstrTypes.h" 28 #include "llvm/IR/Instruction.h" 29 #include "llvm/IR/Instructions.h" 30 #include "llvm/IR/IntrinsicInst.h" 31 #include "llvm/IR/Intrinsics.h" 32 #include "llvm/IR/Type.h" 33 #include "llvm/IR/Value.h" 34 #include "llvm/InitializePasses.h" 35 #include "llvm/Pass.h" 36 #include "llvm/Support/Casting.h" 37 #include "llvm/Transforms/Scalar.h" 38 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 39 #include <algorithm> 40 #include <cassert> 41 42 using namespace llvm; 43 44 #define DEBUG_TYPE "scalarize-masked-mem-intrin" 45 46 namespace { 47 48 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass { 49 public: 50 static char ID; // Pass identification, replacement for typeid 51 52 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) { 53 initializeScalarizeMaskedMemIntrinLegacyPassPass( 54 *PassRegistry::getPassRegistry()); 55 } 56 57 bool runOnFunction(Function &F) override; 58 59 StringRef getPassName() const override { 60 return "Scalarize Masked Memory Intrinsics"; 61 } 62 63 void getAnalysisUsage(AnalysisUsage &AU) const override { 64 AU.addRequired<TargetTransformInfoWrapperPass>(); 65 AU.addPreserved<DominatorTreeWrapperPass>(); 66 } 67 }; 68 69 } // end anonymous namespace 70 71 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, 72 const TargetTransformInfo &TTI, const DataLayout &DL, 73 DomTreeUpdater *DTU); 74 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, 75 const TargetTransformInfo &TTI, 76 const DataLayout &DL, DomTreeUpdater *DTU); 77 78 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0; 79 80 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, 81 "Scalarize unsupported masked memory intrinsics", false, 82 false) 83 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 84 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 85 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, 86 "Scalarize unsupported masked memory intrinsics", false, 87 false) 88 89 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() { 90 return new ScalarizeMaskedMemIntrinLegacyPass(); 91 } 92 93 static bool isConstantIntVector(Value *Mask) { 94 Constant *C = dyn_cast<Constant>(Mask); 95 if (!C) 96 return false; 97 98 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements(); 99 for (unsigned i = 0; i != NumElts; ++i) { 100 Constant *CElt = C->getAggregateElement(i); 101 if (!CElt || !isa<ConstantInt>(CElt)) 102 return false; 103 } 104 105 return true; 106 } 107 108 static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth, 109 unsigned Idx) { 110 return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx; 111 } 112 113 // Translate a masked load intrinsic like 114 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align, 115 // <16 x i1> %mask, <16 x i32> %passthru) 116 // to a chain of basic blocks, with loading element one-by-one if 117 // the appropriate mask bit is set 118 // 119 // %1 = bitcast i8* %addr to i32* 120 // %2 = extractelement <16 x i1> %mask, i32 0 121 // br i1 %2, label %cond.load, label %else 122 // 123 // cond.load: ; preds = %0 124 // %3 = getelementptr i32* %1, i32 0 125 // %4 = load i32* %3 126 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0 127 // br label %else 128 // 129 // else: ; preds = %0, %cond.load 130 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ] 131 // %6 = extractelement <16 x i1> %mask, i32 1 132 // br i1 %6, label %cond.load1, label %else2 133 // 134 // cond.load1: ; preds = %else 135 // %7 = getelementptr i32* %1, i32 1 136 // %8 = load i32* %7 137 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1 138 // br label %else2 139 // 140 // else2: ; preds = %else, %cond.load1 141 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ] 142 // %10 = extractelement <16 x i1> %mask, i32 2 143 // br i1 %10, label %cond.load4, label %else5 144 // 145 static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI, 146 DomTreeUpdater *DTU, bool &ModifiedDT) { 147 Value *Ptr = CI->getArgOperand(0); 148 Value *Alignment = CI->getArgOperand(1); 149 Value *Mask = CI->getArgOperand(2); 150 Value *Src0 = CI->getArgOperand(3); 151 152 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); 153 VectorType *VecType = cast<FixedVectorType>(CI->getType()); 154 155 Type *EltTy = VecType->getElementType(); 156 157 IRBuilder<> Builder(CI->getContext()); 158 Instruction *InsertPt = CI; 159 BasicBlock *IfBlock = CI->getParent(); 160 161 Builder.SetInsertPoint(InsertPt); 162 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 163 164 // Short-cut if the mask is all-true. 165 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { 166 Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal); 167 CI->replaceAllUsesWith(NewI); 168 CI->eraseFromParent(); 169 return; 170 } 171 172 // Adjust alignment for the scalar instruction. 173 const Align AdjustedAlignVal = 174 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); 175 // Bitcast %addr from i8* to EltTy* 176 Type *NewPtrType = 177 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); 178 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 179 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); 180 181 // The result vector 182 Value *VResult = Src0; 183 184 if (isConstantIntVector(Mask)) { 185 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 186 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 187 continue; 188 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 189 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); 190 VResult = Builder.CreateInsertElement(VResult, Load, Idx); 191 } 192 CI->replaceAllUsesWith(VResult); 193 CI->eraseFromParent(); 194 return; 195 } 196 197 // If the mask is not v1i1, use scalar bit test operations. This generates 198 // better results on X86 at least. 199 Value *SclrMask; 200 if (VectorWidth != 1) { 201 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 202 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 203 } 204 205 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 206 // Fill the "else" block, created in the previous iteration 207 // 208 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 209 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx 210 // %cond = icmp ne i16 %mask_1, 0 211 // br i1 %mask_1, label %cond.load, label %else 212 // 213 Value *Predicate; 214 if (VectorWidth != 1) { 215 Value *Mask = Builder.getInt(APInt::getOneBitSet( 216 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 217 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 218 Builder.getIntN(VectorWidth, 0)); 219 } else { 220 Predicate = Builder.CreateExtractElement(Mask, Idx); 221 } 222 223 // Create "cond" block 224 // 225 // %EltAddr = getelementptr i32* %1, i32 0 226 // %Elt = load i32* %EltAddr 227 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 228 // 229 Instruction *ThenTerm = 230 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 231 /*BranchWeights=*/nullptr, DTU); 232 233 BasicBlock *CondBlock = ThenTerm->getParent(); 234 CondBlock->setName("cond.load"); 235 236 Builder.SetInsertPoint(CondBlock->getTerminator()); 237 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 238 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); 239 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); 240 241 // Create "else" block, fill it in the next iteration 242 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 243 NewIfBlock->setName("else"); 244 BasicBlock *PrevIfBlock = IfBlock; 245 IfBlock = NewIfBlock; 246 247 // Create the phi to join the new and previous value. 248 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 249 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 250 Phi->addIncoming(NewVResult, CondBlock); 251 Phi->addIncoming(VResult, PrevIfBlock); 252 VResult = Phi; 253 } 254 255 CI->replaceAllUsesWith(VResult); 256 CI->eraseFromParent(); 257 258 ModifiedDT = true; 259 } 260 261 // Translate a masked store intrinsic, like 262 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, 263 // <16 x i1> %mask) 264 // to a chain of basic blocks, that stores element one-by-one if 265 // the appropriate mask bit is set 266 // 267 // %1 = bitcast i8* %addr to i32* 268 // %2 = extractelement <16 x i1> %mask, i32 0 269 // br i1 %2, label %cond.store, label %else 270 // 271 // cond.store: ; preds = %0 272 // %3 = extractelement <16 x i32> %val, i32 0 273 // %4 = getelementptr i32* %1, i32 0 274 // store i32 %3, i32* %4 275 // br label %else 276 // 277 // else: ; preds = %0, %cond.store 278 // %5 = extractelement <16 x i1> %mask, i32 1 279 // br i1 %5, label %cond.store1, label %else2 280 // 281 // cond.store1: ; preds = %else 282 // %6 = extractelement <16 x i32> %val, i32 1 283 // %7 = getelementptr i32* %1, i32 1 284 // store i32 %6, i32* %7 285 // br label %else2 286 // . . . 287 static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI, 288 DomTreeUpdater *DTU, bool &ModifiedDT) { 289 Value *Src = CI->getArgOperand(0); 290 Value *Ptr = CI->getArgOperand(1); 291 Value *Alignment = CI->getArgOperand(2); 292 Value *Mask = CI->getArgOperand(3); 293 294 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); 295 auto *VecType = cast<VectorType>(Src->getType()); 296 297 Type *EltTy = VecType->getElementType(); 298 299 IRBuilder<> Builder(CI->getContext()); 300 Instruction *InsertPt = CI; 301 Builder.SetInsertPoint(InsertPt); 302 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 303 304 // Short-cut if the mask is all-true. 305 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { 306 Builder.CreateAlignedStore(Src, Ptr, AlignVal); 307 CI->eraseFromParent(); 308 return; 309 } 310 311 // Adjust alignment for the scalar instruction. 312 const Align AdjustedAlignVal = 313 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); 314 // Bitcast %addr from i8* to EltTy* 315 Type *NewPtrType = 316 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); 317 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 318 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); 319 320 if (isConstantIntVector(Mask)) { 321 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 322 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 323 continue; 324 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 325 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 326 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); 327 } 328 CI->eraseFromParent(); 329 return; 330 } 331 332 // If the mask is not v1i1, use scalar bit test operations. This generates 333 // better results on X86 at least. 334 Value *SclrMask; 335 if (VectorWidth != 1) { 336 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 337 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 338 } 339 340 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 341 // Fill the "else" block, created in the previous iteration 342 // 343 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx 344 // %cond = icmp ne i16 %mask_1, 0 345 // br i1 %mask_1, label %cond.store, label %else 346 // 347 Value *Predicate; 348 if (VectorWidth != 1) { 349 Value *Mask = Builder.getInt(APInt::getOneBitSet( 350 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 351 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 352 Builder.getIntN(VectorWidth, 0)); 353 } else { 354 Predicate = Builder.CreateExtractElement(Mask, Idx); 355 } 356 357 // Create "cond" block 358 // 359 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 360 // %EltAddr = getelementptr i32* %1, i32 0 361 // %store i32 %OneElt, i32* %EltAddr 362 // 363 Instruction *ThenTerm = 364 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 365 /*BranchWeights=*/nullptr, DTU); 366 367 BasicBlock *CondBlock = ThenTerm->getParent(); 368 CondBlock->setName("cond.store"); 369 370 Builder.SetInsertPoint(CondBlock->getTerminator()); 371 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 372 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 373 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); 374 375 // Create "else" block, fill it in the next iteration 376 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 377 NewIfBlock->setName("else"); 378 379 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 380 } 381 CI->eraseFromParent(); 382 383 ModifiedDT = true; 384 } 385 386 // Translate a masked gather intrinsic like 387 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, 388 // <16 x i1> %Mask, <16 x i32> %Src) 389 // to a chain of basic blocks, with loading element one-by-one if 390 // the appropriate mask bit is set 391 // 392 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind 393 // %Mask0 = extractelement <16 x i1> %Mask, i32 0 394 // br i1 %Mask0, label %cond.load, label %else 395 // 396 // cond.load: 397 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 398 // %Load0 = load i32, i32* %Ptr0, align 4 399 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0 400 // br label %else 401 // 402 // else: 403 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0] 404 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 405 // br i1 %Mask1, label %cond.load1, label %else2 406 // 407 // cond.load1: 408 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 409 // %Load1 = load i32, i32* %Ptr1, align 4 410 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1 411 // br label %else2 412 // . . . 413 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src 414 // ret <16 x i32> %Result 415 static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI, 416 DomTreeUpdater *DTU, bool &ModifiedDT) { 417 Value *Ptrs = CI->getArgOperand(0); 418 Value *Alignment = CI->getArgOperand(1); 419 Value *Mask = CI->getArgOperand(2); 420 Value *Src0 = CI->getArgOperand(3); 421 422 auto *VecType = cast<FixedVectorType>(CI->getType()); 423 Type *EltTy = VecType->getElementType(); 424 425 IRBuilder<> Builder(CI->getContext()); 426 Instruction *InsertPt = CI; 427 BasicBlock *IfBlock = CI->getParent(); 428 Builder.SetInsertPoint(InsertPt); 429 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); 430 431 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 432 433 // The result vector 434 Value *VResult = Src0; 435 unsigned VectorWidth = VecType->getNumElements(); 436 437 // Shorten the way if the mask is a vector of constants. 438 if (isConstantIntVector(Mask)) { 439 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 440 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 441 continue; 442 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 443 LoadInst *Load = 444 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); 445 VResult = 446 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); 447 } 448 CI->replaceAllUsesWith(VResult); 449 CI->eraseFromParent(); 450 return; 451 } 452 453 // If the mask is not v1i1, use scalar bit test operations. This generates 454 // better results on X86 at least. 455 Value *SclrMask; 456 if (VectorWidth != 1) { 457 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 458 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 459 } 460 461 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 462 // Fill the "else" block, created in the previous iteration 463 // 464 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx 465 // %cond = icmp ne i16 %mask_1, 0 466 // br i1 %Mask1, label %cond.load, label %else 467 // 468 469 Value *Predicate; 470 if (VectorWidth != 1) { 471 Value *Mask = Builder.getInt(APInt::getOneBitSet( 472 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 473 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 474 Builder.getIntN(VectorWidth, 0)); 475 } else { 476 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 477 } 478 479 // Create "cond" block 480 // 481 // %EltAddr = getelementptr i32* %1, i32 0 482 // %Elt = load i32* %EltAddr 483 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 484 // 485 Instruction *ThenTerm = 486 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 487 /*BranchWeights=*/nullptr, DTU); 488 489 BasicBlock *CondBlock = ThenTerm->getParent(); 490 CondBlock->setName("cond.load"); 491 492 Builder.SetInsertPoint(CondBlock->getTerminator()); 493 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 494 LoadInst *Load = 495 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); 496 Value *NewVResult = 497 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); 498 499 // Create "else" block, fill it in the next iteration 500 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 501 NewIfBlock->setName("else"); 502 BasicBlock *PrevIfBlock = IfBlock; 503 IfBlock = NewIfBlock; 504 505 // Create the phi to join the new and previous value. 506 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 507 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 508 Phi->addIncoming(NewVResult, CondBlock); 509 Phi->addIncoming(VResult, PrevIfBlock); 510 VResult = Phi; 511 } 512 513 CI->replaceAllUsesWith(VResult); 514 CI->eraseFromParent(); 515 516 ModifiedDT = true; 517 } 518 519 // Translate a masked scatter intrinsic, like 520 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4, 521 // <16 x i1> %Mask) 522 // to a chain of basic blocks, that stores element one-by-one if 523 // the appropriate mask bit is set. 524 // 525 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind 526 // %Mask0 = extractelement <16 x i1> %Mask, i32 0 527 // br i1 %Mask0, label %cond.store, label %else 528 // 529 // cond.store: 530 // %Elt0 = extractelement <16 x i32> %Src, i32 0 531 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 532 // store i32 %Elt0, i32* %Ptr0, align 4 533 // br label %else 534 // 535 // else: 536 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 537 // br i1 %Mask1, label %cond.store1, label %else2 538 // 539 // cond.store1: 540 // %Elt1 = extractelement <16 x i32> %Src, i32 1 541 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 542 // store i32 %Elt1, i32* %Ptr1, align 4 543 // br label %else2 544 // . . . 545 static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI, 546 DomTreeUpdater *DTU, bool &ModifiedDT) { 547 Value *Src = CI->getArgOperand(0); 548 Value *Ptrs = CI->getArgOperand(1); 549 Value *Alignment = CI->getArgOperand(2); 550 Value *Mask = CI->getArgOperand(3); 551 552 auto *SrcFVTy = cast<FixedVectorType>(Src->getType()); 553 554 assert( 555 isa<VectorType>(Ptrs->getType()) && 556 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) && 557 "Vector of pointers is expected in masked scatter intrinsic"); 558 559 IRBuilder<> Builder(CI->getContext()); 560 Instruction *InsertPt = CI; 561 Builder.SetInsertPoint(InsertPt); 562 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 563 564 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); 565 unsigned VectorWidth = SrcFVTy->getNumElements(); 566 567 // Shorten the way if the mask is a vector of constants. 568 if (isConstantIntVector(Mask)) { 569 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 570 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 571 continue; 572 Value *OneElt = 573 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 574 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 575 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 576 } 577 CI->eraseFromParent(); 578 return; 579 } 580 581 // If the mask is not v1i1, use scalar bit test operations. This generates 582 // better results on X86 at least. 583 Value *SclrMask; 584 if (VectorWidth != 1) { 585 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 586 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 587 } 588 589 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 590 // Fill the "else" block, created in the previous iteration 591 // 592 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx 593 // %cond = icmp ne i16 %mask_1, 0 594 // br i1 %Mask1, label %cond.store, label %else 595 // 596 Value *Predicate; 597 if (VectorWidth != 1) { 598 Value *Mask = Builder.getInt(APInt::getOneBitSet( 599 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 600 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 601 Builder.getIntN(VectorWidth, 0)); 602 } else { 603 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 604 } 605 606 // Create "cond" block 607 // 608 // %Elt1 = extractelement <16 x i32> %Src, i32 1 609 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 610 // %store i32 %Elt1, i32* %Ptr1 611 // 612 Instruction *ThenTerm = 613 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 614 /*BranchWeights=*/nullptr, DTU); 615 616 BasicBlock *CondBlock = ThenTerm->getParent(); 617 CondBlock->setName("cond.store"); 618 619 Builder.SetInsertPoint(CondBlock->getTerminator()); 620 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 621 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 622 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 623 624 // Create "else" block, fill it in the next iteration 625 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 626 NewIfBlock->setName("else"); 627 628 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 629 } 630 CI->eraseFromParent(); 631 632 ModifiedDT = true; 633 } 634 635 static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI, 636 DomTreeUpdater *DTU, bool &ModifiedDT) { 637 Value *Ptr = CI->getArgOperand(0); 638 Value *Mask = CI->getArgOperand(1); 639 Value *PassThru = CI->getArgOperand(2); 640 641 auto *VecType = cast<FixedVectorType>(CI->getType()); 642 643 Type *EltTy = VecType->getElementType(); 644 645 IRBuilder<> Builder(CI->getContext()); 646 Instruction *InsertPt = CI; 647 BasicBlock *IfBlock = CI->getParent(); 648 649 Builder.SetInsertPoint(InsertPt); 650 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 651 652 unsigned VectorWidth = VecType->getNumElements(); 653 654 // The result vector 655 Value *VResult = PassThru; 656 657 // Shorten the way if the mask is a vector of constants. 658 // Create a build_vector pattern, with loads/undefs as necessary and then 659 // shuffle blend with the pass through value. 660 if (isConstantIntVector(Mask)) { 661 unsigned MemIndex = 0; 662 VResult = UndefValue::get(VecType); 663 SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); 664 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 665 Value *InsertElt; 666 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) { 667 InsertElt = UndefValue::get(EltTy); 668 ShuffleMask[Idx] = Idx + VectorWidth; 669 } else { 670 Value *NewPtr = 671 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); 672 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1), 673 "Load" + Twine(Idx)); 674 ShuffleMask[Idx] = Idx; 675 ++MemIndex; 676 } 677 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx, 678 "Res" + Twine(Idx)); 679 } 680 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask); 681 CI->replaceAllUsesWith(VResult); 682 CI->eraseFromParent(); 683 return; 684 } 685 686 // If the mask is not v1i1, use scalar bit test operations. This generates 687 // better results on X86 at least. 688 Value *SclrMask; 689 if (VectorWidth != 1) { 690 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 691 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 692 } 693 694 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 695 // Fill the "else" block, created in the previous iteration 696 // 697 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 698 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 699 // br i1 %mask_1, label %cond.load, label %else 700 // 701 702 Value *Predicate; 703 if (VectorWidth != 1) { 704 Value *Mask = Builder.getInt(APInt::getOneBitSet( 705 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 706 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 707 Builder.getIntN(VectorWidth, 0)); 708 } else { 709 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 710 } 711 712 // Create "cond" block 713 // 714 // %EltAddr = getelementptr i32* %1, i32 0 715 // %Elt = load i32* %EltAddr 716 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 717 // 718 Instruction *ThenTerm = 719 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 720 /*BranchWeights=*/nullptr, DTU); 721 722 BasicBlock *CondBlock = ThenTerm->getParent(); 723 CondBlock->setName("cond.load"); 724 725 Builder.SetInsertPoint(CondBlock->getTerminator()); 726 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1)); 727 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); 728 729 // Move the pointer if there are more blocks to come. 730 Value *NewPtr; 731 if ((Idx + 1) != VectorWidth) 732 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); 733 734 // Create "else" block, fill it in the next iteration 735 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 736 NewIfBlock->setName("else"); 737 BasicBlock *PrevIfBlock = IfBlock; 738 IfBlock = NewIfBlock; 739 740 // Create the phi to join the new and previous value. 741 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 742 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 743 ResultPhi->addIncoming(NewVResult, CondBlock); 744 ResultPhi->addIncoming(VResult, PrevIfBlock); 745 VResult = ResultPhi; 746 747 // Add a PHI for the pointer if this isn't the last iteration. 748 if ((Idx + 1) != VectorWidth) { 749 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); 750 PtrPhi->addIncoming(NewPtr, CondBlock); 751 PtrPhi->addIncoming(Ptr, PrevIfBlock); 752 Ptr = PtrPhi; 753 } 754 } 755 756 CI->replaceAllUsesWith(VResult); 757 CI->eraseFromParent(); 758 759 ModifiedDT = true; 760 } 761 762 static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI, 763 DomTreeUpdater *DTU, 764 bool &ModifiedDT) { 765 Value *Src = CI->getArgOperand(0); 766 Value *Ptr = CI->getArgOperand(1); 767 Value *Mask = CI->getArgOperand(2); 768 769 auto *VecType = cast<FixedVectorType>(Src->getType()); 770 771 IRBuilder<> Builder(CI->getContext()); 772 Instruction *InsertPt = CI; 773 BasicBlock *IfBlock = CI->getParent(); 774 775 Builder.SetInsertPoint(InsertPt); 776 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 777 778 Type *EltTy = VecType->getElementType(); 779 780 unsigned VectorWidth = VecType->getNumElements(); 781 782 // Shorten the way if the mask is a vector of constants. 783 if (isConstantIntVector(Mask)) { 784 unsigned MemIndex = 0; 785 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 786 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 787 continue; 788 Value *OneElt = 789 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 790 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); 791 Builder.CreateAlignedStore(OneElt, NewPtr, Align(1)); 792 ++MemIndex; 793 } 794 CI->eraseFromParent(); 795 return; 796 } 797 798 // If the mask is not v1i1, use scalar bit test operations. This generates 799 // better results on X86 at least. 800 Value *SclrMask; 801 if (VectorWidth != 1) { 802 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 803 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 804 } 805 806 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 807 // Fill the "else" block, created in the previous iteration 808 // 809 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 810 // br i1 %mask_1, label %cond.store, label %else 811 // 812 Value *Predicate; 813 if (VectorWidth != 1) { 814 Value *Mask = Builder.getInt(APInt::getOneBitSet( 815 VectorWidth, adjustForEndian(DL, VectorWidth, Idx))); 816 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 817 Builder.getIntN(VectorWidth, 0)); 818 } else { 819 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 820 } 821 822 // Create "cond" block 823 // 824 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 825 // %EltAddr = getelementptr i32* %1, i32 0 826 // %store i32 %OneElt, i32* %EltAddr 827 // 828 Instruction *ThenTerm = 829 SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false, 830 /*BranchWeights=*/nullptr, DTU); 831 832 BasicBlock *CondBlock = ThenTerm->getParent(); 833 CondBlock->setName("cond.store"); 834 835 Builder.SetInsertPoint(CondBlock->getTerminator()); 836 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 837 Builder.CreateAlignedStore(OneElt, Ptr, Align(1)); 838 839 // Move the pointer if there are more blocks to come. 840 Value *NewPtr; 841 if ((Idx + 1) != VectorWidth) 842 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); 843 844 // Create "else" block, fill it in the next iteration 845 BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0); 846 NewIfBlock->setName("else"); 847 BasicBlock *PrevIfBlock = IfBlock; 848 IfBlock = NewIfBlock; 849 850 Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin()); 851 852 // Add a PHI for the pointer if this isn't the last iteration. 853 if ((Idx + 1) != VectorWidth) { 854 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); 855 PtrPhi->addIncoming(NewPtr, CondBlock); 856 PtrPhi->addIncoming(Ptr, PrevIfBlock); 857 Ptr = PtrPhi; 858 } 859 } 860 CI->eraseFromParent(); 861 862 ModifiedDT = true; 863 } 864 865 static bool runImpl(Function &F, const TargetTransformInfo &TTI, 866 DominatorTree *DT) { 867 Optional<DomTreeUpdater> DTU; 868 if (DT) 869 DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy); 870 871 bool EverMadeChange = false; 872 bool MadeChange = true; 873 auto &DL = F.getParent()->getDataLayout(); 874 while (MadeChange) { 875 MadeChange = false; 876 for (BasicBlock &BB : llvm::make_early_inc_range(F)) { 877 bool ModifiedDTOnIteration = false; 878 MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL, 879 DTU.hasValue() ? DTU.getPointer() : nullptr); 880 881 // Restart BB iteration if the dominator tree of the Function was changed 882 if (ModifiedDTOnIteration) 883 break; 884 } 885 886 EverMadeChange |= MadeChange; 887 } 888 return EverMadeChange; 889 } 890 891 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) { 892 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 893 DominatorTree *DT = nullptr; 894 if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>()) 895 DT = &DTWP->getDomTree(); 896 return runImpl(F, TTI, DT); 897 } 898 899 PreservedAnalyses 900 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) { 901 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 902 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); 903 if (!runImpl(F, TTI, DT)) 904 return PreservedAnalyses::all(); 905 PreservedAnalyses PA; 906 PA.preserve<TargetIRAnalysis>(); 907 PA.preserve<DominatorTreeAnalysis>(); 908 return PA; 909 } 910 911 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, 912 const TargetTransformInfo &TTI, const DataLayout &DL, 913 DomTreeUpdater *DTU) { 914 bool MadeChange = false; 915 916 BasicBlock::iterator CurInstIterator = BB.begin(); 917 while (CurInstIterator != BB.end()) { 918 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++)) 919 MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU); 920 if (ModifiedDT) 921 return true; 922 } 923 924 return MadeChange; 925 } 926 927 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, 928 const TargetTransformInfo &TTI, 929 const DataLayout &DL, DomTreeUpdater *DTU) { 930 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 931 if (II) { 932 // The scalarization code below does not work for scalable vectors. 933 if (isa<ScalableVectorType>(II->getType()) || 934 any_of(II->args(), 935 [](Value *V) { return isa<ScalableVectorType>(V->getType()); })) 936 return false; 937 938 switch (II->getIntrinsicID()) { 939 default: 940 break; 941 case Intrinsic::masked_load: 942 // Scalarize unsupported vector masked load 943 if (TTI.isLegalMaskedLoad( 944 CI->getType(), 945 cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue())) 946 return false; 947 scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT); 948 return true; 949 case Intrinsic::masked_store: 950 if (TTI.isLegalMaskedStore( 951 CI->getArgOperand(0)->getType(), 952 cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue())) 953 return false; 954 scalarizeMaskedStore(DL, CI, DTU, ModifiedDT); 955 return true; 956 case Intrinsic::masked_gather: { 957 MaybeAlign MA = 958 cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue(); 959 Type *LoadTy = CI->getType(); 960 Align Alignment = DL.getValueOrABITypeAlignment(MA, 961 LoadTy->getScalarType()); 962 if (TTI.isLegalMaskedGather(LoadTy, Alignment) && 963 !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment)) 964 return false; 965 scalarizeMaskedGather(DL, CI, DTU, ModifiedDT); 966 return true; 967 } 968 case Intrinsic::masked_scatter: { 969 MaybeAlign MA = 970 cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue(); 971 Type *StoreTy = CI->getArgOperand(0)->getType(); 972 Align Alignment = DL.getValueOrABITypeAlignment(MA, 973 StoreTy->getScalarType()); 974 if (TTI.isLegalMaskedScatter(StoreTy, Alignment) && 975 !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy), 976 Alignment)) 977 return false; 978 scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT); 979 return true; 980 } 981 case Intrinsic::masked_expandload: 982 if (TTI.isLegalMaskedExpandLoad(CI->getType())) 983 return false; 984 scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT); 985 return true; 986 case Intrinsic::masked_compressstore: 987 if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) 988 return false; 989 scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT); 990 return true; 991 } 992 } 993 994 return false; 995 } 996