//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

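// Map a vector lane index to its bit position in the scalarized mask: when an
// <N x i1> mask is bitcast to an iN integer, lane 0 lands in the least
// significant bit on little-endian targets and in the most significant bit on
// big-endian targets, so the index is mirrored for big-endian.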
static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
// %3 = getelementptr i32* %1, i32 0
// %4 = load i32* %3
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
// br label %else
//
// else:                                             ; preds = %0, %cond.load
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ poison, %0 ]
// %6 = extractelement <16 x i1> %mask, i32 1
// br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
// %7 = getelementptr i32* %1, i32 1
// %8 = load i32* %7
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
// br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
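  // For example, a <16 x i1> mask is bitcast to an i16 ("scalar_mask"), and
  // lane Idx is then tested with an 'and' against (1 << Idx) followed by an
  // 'icmp ne 0', instead of an extractelement per lane.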
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that store the elements one-by-one if
// the appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
// %3 = extractelement <16 x i32> %val, i32 0
// %4 = getelementptr i32* %1, i32 0
// store i32 %3, i32* %4
// br label %else
//
// else:                                             ; preds = %0, %cond.store
// %5 = extractelement <16 x i1> %mask, i32 1
// br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
// %6 = extractelement <16 x i32> %val, i32 1
// %7 = getelementptr i32* %1, i32 1
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> poison, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ poison, %0 ]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

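// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                             <16 x i32> %passthru)
// to a chain of basic blocks in the same style as the masked load above,
// except that enabled lanes read consecutive elements starting at %ptr. The
// pointer only advances past an element when the corresponding mask bit is
// set, so a second phi carries the current pointer from one conditional block
// to the next (sketch of the generated blocks; value names are illustrative):
//
// cond.load:
// %Elt = load i32, i32* %Ptr, align 1
// %Res = insertelement <16 x i32> %VResult, i32 %Elt, i32 Idx
// %NextPtr = getelementptr inbounds i32, i32* %Ptr, i32 1
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res, %cond.load ], [ %VResult, ... ]
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.load ], [ %Ptr, ... ]
// . . .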
static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/poison as necessary, and then
  // blend it with the pass-through value using a shufflevector.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, PoisonMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = PoisonValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

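// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                          <16 x i1> %mask)
// to a chain of basic blocks in the same style as the masked store above,
// except that the enabled lanes are packed into consecutive memory locations
// starting at %ptr. As with expandload, the pointer only advances when a lane
// is actually stored, so a phi carries it between the conditional blocks
// (sketch of the generated blocks; value names are illustrative):
//
// cond.store:
// %Elt = extractelement <16 x i32> %value, i32 Idx
// store i32 %Elt, i32* %Ptr, align 1
// %NextPtr = getelementptr inbounds i32, i32* %Ptr, i32 1
// br label %else
//
// else:
// %ptr.phi.else = phi i32* [ %NextPtr, %cond.store ], [ %Ptr, ... ]
// . . .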
static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

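// Walk every block of F and scalarize any masked memory intrinsics the target
// does not support natively. Whenever a scalarization splits blocks (reported
// through ModifiedDT), block iteration is restarted on the updated function.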
static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MA, StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}