//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deals with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <cassert>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

static unsigned adjustForEndian(const DataLayout &DL, unsigned VectorWidth,
                                unsigned Idx) {
  return DL.isBigEndian() ? VectorWidth - 1 - Idx : Idx;
}

// Translate a masked load intrinsic like
// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                               <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(const DataLayout &DL, CallInst *CI,
                                DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
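  // For example, a <16 x i1> mask is tested lane-by-lane roughly as
  //   %scalar_mask = bitcast <16 x i1> %mask to i16
  //   %bit  = and i16 %scalar_mask, <one-bit constant for lane Idx>
  //   %cond = icmp ne i16 %bit, 0
  // rather than emitting one extractelement per lane.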
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(const DataLayout &DL, CallInst *CI,
                                 DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(const DataLayout &DL, CallInst *CI,
                                  DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
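    // On the taken path the phi receives the vector with the freshly loaded
    // lane inserted; on the skipped path it receives the previous value.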
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks, storing the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(const DataLayout &DL, CallInst *CI,
                                   DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i32 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %Mask1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static void scalarizeMaskedExpandLoad(const DataLayout &DL, CallInst *CI,
                                      DomTreeUpdater *DTU, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
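  // For example, with a constant mask <i1 1, i1 0, i1 1, i1 1> lanes 0, 2 and
  // 3 are filled from consecutive memory locations, and the final shuffle
  // takes lane 1 from the pass-through vector.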
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = PoisonValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
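    // The pointer only advances on the path where a lane was actually loaded,
    // so consecutive set mask bits read consecutive memory locations.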
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

static void scalarizeMaskedCompressStore(const DataLayout &DL, CallInst *CI,
                                         DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(
          VectorWidth, adjustForEndian(DL, VectorWidth, Idx)));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
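    // As in the expand-load case, the destination pointer advances only when
    // a lane is actually stored, so the selected elements end up contiguous.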
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  std::optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (BasicBlock &BB : llvm::make_early_inc_range(F)) {
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU ? &*DTU : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
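    // A scalable vector's element count is unknown at compile time, so we
    // cannot emit a fixed per-lane chain of blocks; bail out and leave such
    // calls untouched.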
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->args(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(1))->getMaybeAlignValue();
      Type *LoadTy = CI->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      LoadTy->getScalarType());
      if (TTI.isLegalMaskedGather(LoadTy, Alignment) &&
          !TTI.forceScalarizeMaskedGather(cast<VectorType>(LoadTy), Alignment))
        return false;
      scalarizeMaskedGather(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      MaybeAlign MA =
          cast<ConstantInt>(CI->getArgOperand(2))->getMaybeAlignValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment = DL.getValueOrABITypeAlignment(MA,
                                                      StoreTy->getScalarType());
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment) &&
          !TTI.forceScalarizeMaskedScatter(cast<VectorType>(StoreTy),
                                           Alignment))
        return false;
      scalarizeMaskedScatter(DL, CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(DL, CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(DL, CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}