1 //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===// 2 // intrinsics 3 // 4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 // See https://llvm.org/LICENSE.txt for license information. 6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This pass replaces masked memory intrinsics - when unsupported by the target 11 // - with a chain of basic blocks that deal with the elements one by one if the 12 // appropriate mask bit is set. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h" 17 #include "llvm/ADT/Twine.h" 18 #include "llvm/Analysis/TargetTransformInfo.h" 19 #include "llvm/IR/BasicBlock.h" 20 #include "llvm/IR/Constant.h" 21 #include "llvm/IR/Constants.h" 22 #include "llvm/IR/DerivedTypes.h" 23 #include "llvm/IR/Function.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/IR/InstrTypes.h" 26 #include "llvm/IR/Instruction.h" 27 #include "llvm/IR/Instructions.h" 28 #include "llvm/IR/IntrinsicInst.h" 29 #include "llvm/IR/Intrinsics.h" 30 #include "llvm/IR/Type.h" 31 #include "llvm/IR/Value.h" 32 #include "llvm/InitializePasses.h" 33 #include "llvm/Pass.h" 34 #include "llvm/Support/Casting.h" 35 #include "llvm/Transforms/Scalar.h" 36 #include <algorithm> 37 #include <cassert> 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "scalarize-masked-mem-intrin" 42 43 namespace { 44 45 class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass { 46 public: 47 static char ID; // Pass identification, replacement for typeid 48 49 explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) { 50 initializeScalarizeMaskedMemIntrinLegacyPassPass( 51 *PassRegistry::getPassRegistry()); 52 } 53 54 bool runOnFunction(Function &F) override; 55 56 StringRef getPassName() const override { 57 return "Scalarize Masked Memory Intrinsics"; 58 } 59 60 void getAnalysisUsage(AnalysisUsage &AU) const override { 61 AU.addRequired<TargetTransformInfoWrapperPass>(); 62 } 63 }; 64 65 } // end anonymous namespace 66 67 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, 68 const TargetTransformInfo &TTI, const DataLayout &DL); 69 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, 70 const TargetTransformInfo &TTI, 71 const DataLayout &DL); 72 73 char ScalarizeMaskedMemIntrinLegacyPass::ID = 0; 74 75 INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, 76 "Scalarize unsupported masked memory intrinsics", false, 77 false) 78 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 79 INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE, 80 "Scalarize unsupported masked memory intrinsics", false, 81 false) 82 83 FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() { 84 return new ScalarizeMaskedMemIntrinLegacyPass(); 85 } 86 87 static bool isConstantIntVector(Value *Mask) { 88 Constant *C = dyn_cast<Constant>(Mask); 89 if (!C) 90 return false; 91 92 unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements(); 93 for (unsigned i = 0; i != NumElts; ++i) { 94 Constant *CElt = C->getAggregateElement(i); 95 if (!CElt || !isa<ConstantInt>(CElt)) 96 return false; 97 } 98 99 return true; 100 } 101 102 // Translate a masked load intrinsic like 103 // <16 x i32> @llvm.masked.load( <16 x i32>* %addr, i32 align, 104 // <16 x i1> %mask, <16 x i32> %passthru) 105 // to a 
chain of basic blocks that load the elements one by one if 106 // the appropriate mask bit is set 107 // 108 // %1 = bitcast i8* %addr to i32* 109 // %2 = extractelement <16 x i1> %mask, i32 0 110 // br i1 %2, label %cond.load, label %else 111 // 112 // cond.load: ; preds = %0 113 // %3 = getelementptr i32* %1, i32 0 114 // %4 = load i32* %3 115 // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0 116 // br label %else 117 // 118 // else: ; preds = %0, %cond.load 119 // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ] 120 // %6 = extractelement <16 x i1> %mask, i32 1 121 // br i1 %6, label %cond.load1, label %else2 122 // 123 // cond.load1: ; preds = %else 124 // %7 = getelementptr i32* %1, i32 1 125 // %8 = load i32* %7 126 // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1 127 // br label %else2 128 // 129 // else2: ; preds = %else, %cond.load1 130 // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ] 131 // %10 = extractelement <16 x i1> %mask, i32 2 132 // br i1 %10, label %cond.load4, label %else5 133 // 134 static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) { 135 Value *Ptr = CI->getArgOperand(0); 136 Value *Alignment = CI->getArgOperand(1); 137 Value *Mask = CI->getArgOperand(2); 138 Value *Src0 = CI->getArgOperand(3); 139 140 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); 141 VectorType *VecType = cast<FixedVectorType>(CI->getType()); 142 143 Type *EltTy = VecType->getElementType(); 144 145 IRBuilder<> Builder(CI->getContext()); 146 Instruction *InsertPt = CI; 147 BasicBlock *IfBlock = CI->getParent(); 148 149 Builder.SetInsertPoint(InsertPt); 150 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 151 152 // Short-cut if the mask is all-true. 153 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { 154 Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal); 155 CI->replaceAllUsesWith(NewI); 156 CI->eraseFromParent(); 157 return; 158 } 159 160 // Adjust alignment for the scalar instruction. 161 const Align AdjustedAlignVal = 162 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); 163 // Bitcast %addr from i8* to EltTy* 164 Type *NewPtrType = 165 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); 166 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 167 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); 168 169 // The result vector 170 Value *VResult = Src0; 171 172 if (isConstantIntVector(Mask)) { 173 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 174 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 175 continue; 176 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 177 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); 178 VResult = Builder.CreateInsertElement(VResult, Load, Idx); 179 } 180 CI->replaceAllUsesWith(VResult); 181 CI->eraseFromParent(); 182 return; 183 } 184 185 // If the mask is not v1i1, use scalar bit test operations. This generates 186 // better results on X86 at least. 
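// Note (added for clarity): when VectorWidth != 1 the <N x i1> mask is
// reinterpreted as an iN integer so each lane can be tested with a single
// 'and' + 'icmp ne' pair; SclrMask is left unset (and unused) in the
// single-element case.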
187 Value *SclrMask; 188 if (VectorWidth != 1) { 189 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 190 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 191 } 192 193 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 194 // Fill the "else" block, created in the previous iteration 195 // 196 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 197 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx 198 // %cond = icmp ne i16 %mask_1, 0 199 // br i1 %mask_1, label %cond.load, label %else 200 // 201 Value *Predicate; 202 if (VectorWidth != 1) { 203 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 204 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 205 Builder.getIntN(VectorWidth, 0)); 206 } else { 207 Predicate = Builder.CreateExtractElement(Mask, Idx); 208 } 209 210 // Create "cond" block 211 // 212 // %EltAddr = getelementptr i32* %1, i32 0 213 // %Elt = load i32* %EltAddr 214 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 215 // 216 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), 217 "cond.load"); 218 Builder.SetInsertPoint(InsertPt); 219 220 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 221 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal); 222 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); 223 224 // Create "else" block, fill it in the next iteration 225 BasicBlock *NewIfBlock = 226 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 227 Builder.SetInsertPoint(InsertPt); 228 Instruction *OldBr = IfBlock->getTerminator(); 229 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 230 OldBr->eraseFromParent(); 231 BasicBlock *PrevIfBlock = IfBlock; 232 IfBlock = NewIfBlock; 233 234 // Create the phi to join the new and previous value. 235 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 236 Phi->addIncoming(NewVResult, CondBlock); 237 Phi->addIncoming(VResult, PrevIfBlock); 238 VResult = Phi; 239 } 240 241 CI->replaceAllUsesWith(VResult); 242 CI->eraseFromParent(); 243 244 ModifiedDT = true; 245 } 246 247 // Translate a masked store intrinsic, like 248 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align, 249 // <16 x i1> %mask) 250 // to a chain of basic blocks that store the elements one by one if 251 // the appropriate mask bit is set 252 // 253 // %1 = bitcast i8* %addr to i32* 254 // %2 = extractelement <16 x i1> %mask, i32 0 255 // br i1 %2, label %cond.store, label %else 256 // 257 // cond.store: ; preds = %0 258 // %3 = extractelement <16 x i32> %val, i32 0 259 // %4 = getelementptr i32* %1, i32 0 260 // store i32 %3, i32* %4 261 // br label %else 262 // 263 // else: ; preds = %0, %cond.store 264 // %5 = extractelement <16 x i1> %mask, i32 1 265 // br i1 %5, label %cond.store1, label %else2 266 // 267 // cond.store1: ; preds = %else 268 // %6 = extractelement <16 x i32> %val, i32 1 269 // %7 = getelementptr i32* %1, i32 1 270 // store i32 %6, i32* %7 271 // br label %else2 272 // . . . 
273 static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) { 274 Value *Src = CI->getArgOperand(0); 275 Value *Ptr = CI->getArgOperand(1); 276 Value *Alignment = CI->getArgOperand(2); 277 Value *Mask = CI->getArgOperand(3); 278 279 const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue(); 280 auto *VecType = cast<VectorType>(Src->getType()); 281 282 Type *EltTy = VecType->getElementType(); 283 284 IRBuilder<> Builder(CI->getContext()); 285 Instruction *InsertPt = CI; 286 BasicBlock *IfBlock = CI->getParent(); 287 Builder.SetInsertPoint(InsertPt); 288 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 289 290 // Short-cut if the mask is all-true. 291 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) { 292 Builder.CreateAlignedStore(Src, Ptr, AlignVal); 293 CI->eraseFromParent(); 294 return; 295 } 296 297 // Adjust alignment for the scalar instruction. 298 const Align AdjustedAlignVal = 299 commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8); 300 // Bitcast %addr from i8* to EltTy* 301 Type *NewPtrType = 302 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace()); 303 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 304 unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements(); 305 306 if (isConstantIntVector(Mask)) { 307 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 308 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 309 continue; 310 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 311 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 312 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); 313 } 314 CI->eraseFromParent(); 315 return; 316 } 317 318 // If the mask is not v1i1, use scalar bit test operations. This generates 319 // better results on X86 at least. 
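// Unlike the load case, the store loop below needs no result phis: stores
// produce no value, so each iteration only threads control flow through the
// newly created cond.store/else blocks.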
320 Value *SclrMask; 321 if (VectorWidth != 1) { 322 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 323 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 324 } 325 326 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 327 // Fill the "else" block, created in the previous iteration 328 // 329 // %mask_1 = and i16 %scalar_mask, i32 1 << Idx 330 // %cond = icmp ne i16 %mask_1, 0 331 // br i1 %mask_1, label %cond.store, label %else 332 // 333 Value *Predicate; 334 if (VectorWidth != 1) { 335 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 336 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 337 Builder.getIntN(VectorWidth, 0)); 338 } else { 339 Predicate = Builder.CreateExtractElement(Mask, Idx); 340 } 341 342 // Create "cond" block 343 // 344 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 345 // %EltAddr = getelementptr i32* %1, i32 0 346 // %store i32 %OneElt, i32* %EltAddr 347 // 348 BasicBlock *CondBlock = 349 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); 350 Builder.SetInsertPoint(InsertPt); 351 352 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 353 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx); 354 Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal); 355 356 // Create "else" block, fill it in the next iteration 357 BasicBlock *NewIfBlock = 358 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 359 Builder.SetInsertPoint(InsertPt); 360 Instruction *OldBr = IfBlock->getTerminator(); 361 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 362 OldBr->eraseFromParent(); 363 IfBlock = NewIfBlock; 364 } 365 CI->eraseFromParent(); 366 367 ModifiedDT = true; 368 } 369 370 // Translate a masked gather intrinsic like 371 // <16 x i32> @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, 372 // <16 x i1> %Mask, <16 x i32> %Src) 373 // to a chain of basic blocks that load the elements one by one if 374 // the appropriate mask bit is set 375 // 376 // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind 377 // %Mask0 = extractelement <16 x i1> %Mask, i32 0 378 // br i1 %Mask0, label %cond.load, label %else 379 // 380 // cond.load: 381 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 382 // %Load0 = load i32, i32* %Ptr0, align 4 383 // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0 384 // br label %else 385 // 386 // else: 387 // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0] 388 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 389 // br i1 %Mask1, label %cond.load1, label %else2 390 // 391 // cond.load1: 392 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 393 // %Load1 = load i32, i32* %Ptr1, align 4 394 // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1 395 // br label %else2 396 // . . . 
397 // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src 398 // ret <16 x i32> %Result 399 static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) { 400 Value *Ptrs = CI->getArgOperand(0); 401 Value *Alignment = CI->getArgOperand(1); 402 Value *Mask = CI->getArgOperand(2); 403 Value *Src0 = CI->getArgOperand(3); 404 405 auto *VecType = cast<FixedVectorType>(CI->getType()); 406 Type *EltTy = VecType->getElementType(); 407 408 IRBuilder<> Builder(CI->getContext()); 409 Instruction *InsertPt = CI; 410 BasicBlock *IfBlock = CI->getParent(); 411 Builder.SetInsertPoint(InsertPt); 412 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); 413 414 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 415 416 // The result vector 417 Value *VResult = Src0; 418 unsigned VectorWidth = VecType->getNumElements(); 419 420 // Shorten the way if the mask is a vector of constants. 421 if (isConstantIntVector(Mask)) { 422 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 423 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 424 continue; 425 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 426 LoadInst *Load = 427 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); 428 VResult = 429 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); 430 } 431 CI->replaceAllUsesWith(VResult); 432 CI->eraseFromParent(); 433 return; 434 } 435 436 // If the mask is not v1i1, use scalar bit test operations. This generates 437 // better results on X86 at least. 438 Value *SclrMask; 439 if (VectorWidth != 1) { 440 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 441 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 442 } 443 444 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 445 // Fill the "else" block, created in the previous iteration 446 // 447 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx 448 // %cond = icmp ne i16 %mask_1, 0 449 // br i1 %Mask1, label %cond.load, label %else 450 // 451 452 Value *Predicate; 453 if (VectorWidth != 1) { 454 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 455 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 456 Builder.getIntN(VectorWidth, 0)); 457 } else { 458 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 459 } 460 461 // Create "cond" block 462 // 463 // %EltAddr = getelementptr i32* %1, i32 0 464 // %Elt = load i32* %EltAddr 465 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 466 // 467 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); 468 Builder.SetInsertPoint(InsertPt); 469 470 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 471 LoadInst *Load = 472 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx)); 473 Value *NewVResult = 474 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx)); 475 476 // Create "else" block, fill it in the next iteration 477 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); 478 Builder.SetInsertPoint(InsertPt); 479 Instruction *OldBr = IfBlock->getTerminator(); 480 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 481 OldBr->eraseFromParent(); 482 BasicBlock *PrevIfBlock = IfBlock; 483 IfBlock = NewIfBlock; 484 485 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 486 Phi->addIncoming(NewVResult, CondBlock); 487 Phi->addIncoming(VResult, PrevIfBlock); 488 VResult = Phi; 489 } 490 491 
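// At this point VResult refers to the final phi, holding the loaded value for
// each enabled lane and the pass-through value for the disabled ones.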
CI->replaceAllUsesWith(VResult); 492 CI->eraseFromParent(); 493 494 ModifiedDT = true; 495 } 496 497 // Translate a masked scatter intrinsic, like 498 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4, 499 // <16 x i1> %Mask) 500 // to a chain of basic blocks that store the elements one by one if 501 // the appropriate mask bit is set. 502 // 503 // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind 504 // %Mask0 = extractelement <16 x i1> %Mask, i32 0 505 // br i1 %Mask0, label %cond.store, label %else 506 // 507 // cond.store: 508 // %Elt0 = extractelement <16 x i32> %Src, i32 0 509 // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 510 // store i32 %Elt0, i32* %Ptr0, align 4 511 // br label %else 512 // 513 // else: 514 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 515 // br i1 %Mask1, label %cond.store1, label %else2 516 // 517 // cond.store1: 518 // %Elt1 = extractelement <16 x i32> %Src, i32 1 519 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 520 // store i32 %Elt1, i32* %Ptr1, align 4 521 // br label %else2 522 // . . . 523 static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) { 524 Value *Src = CI->getArgOperand(0); 525 Value *Ptrs = CI->getArgOperand(1); 526 Value *Alignment = CI->getArgOperand(2); 527 Value *Mask = CI->getArgOperand(3); 528 529 auto *SrcFVTy = cast<FixedVectorType>(Src->getType()); 530 531 assert( 532 isa<VectorType>(Ptrs->getType()) && 533 isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) && 534 "Vector of pointers is expected in masked scatter intrinsic"); 535 536 IRBuilder<> Builder(CI->getContext()); 537 Instruction *InsertPt = CI; 538 BasicBlock *IfBlock = CI->getParent(); 539 Builder.SetInsertPoint(InsertPt); 540 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 541 542 MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue(); 543 unsigned VectorWidth = SrcFVTy->getNumElements(); 544 545 // Shorten the way if the mask is a vector of constants. 546 if (isConstantIntVector(Mask)) { 547 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 548 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 549 continue; 550 Value *OneElt = 551 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 552 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 553 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 554 } 555 CI->eraseFromParent(); 556 return; 557 } 558 559 // If the mask is not v1i1, use scalar bit test operations. This generates 560 // better results on X86 at least. 
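// The per-element loop below mirrors the masked-store lowering, except that
// each cond.store block also extracts its own pointer lane from Ptrs.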
561 Value *SclrMask; 562 if (VectorWidth != 1) { 563 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 564 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 565 } 566 567 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 568 // Fill the "else" block, created in the previous iteration 569 // 570 // %Mask1 = and i16 %scalar_mask, i32 1 << Idx 571 // %cond = icmp ne i16 %mask_1, 0 572 // br i1 %Mask1, label %cond.store, label %else 573 // 574 Value *Predicate; 575 if (VectorWidth != 1) { 576 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 577 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 578 Builder.getIntN(VectorWidth, 0)); 579 } else { 580 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 581 } 582 583 // Create "cond" block 584 // 585 // %Elt1 = extractelement <16 x i32> %Src, i32 1 586 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 587 // %store i32 %Elt1, i32* %Ptr1 588 // 589 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store"); 590 Builder.SetInsertPoint(InsertPt); 591 592 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 593 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx)); 594 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal); 595 596 // Create "else" block, fill it in the next iteration 597 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); 598 Builder.SetInsertPoint(InsertPt); 599 Instruction *OldBr = IfBlock->getTerminator(); 600 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 601 OldBr->eraseFromParent(); 602 IfBlock = NewIfBlock; 603 } 604 CI->eraseFromParent(); 605 606 ModifiedDT = true; 607 } 608 609 static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) { 610 Value *Ptr = CI->getArgOperand(0); 611 Value *Mask = CI->getArgOperand(1); 612 Value *PassThru = CI->getArgOperand(2); 613 614 auto *VecType = cast<FixedVectorType>(CI->getType()); 615 616 Type *EltTy = VecType->getElementType(); 617 618 IRBuilder<> Builder(CI->getContext()); 619 Instruction *InsertPt = CI; 620 BasicBlock *IfBlock = CI->getParent(); 621 622 Builder.SetInsertPoint(InsertPt); 623 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 624 625 unsigned VectorWidth = VecType->getNumElements(); 626 627 // The result vector 628 Value *VResult = PassThru; 629 630 // Shorten the way if the mask is a vector of constants. 631 // Create a build_vector pattern, with loads/undefs as necessary and then 632 // shuffle blend with the pass through value. 633 if (isConstantIntVector(Mask)) { 634 unsigned MemIndex = 0; 635 VResult = UndefValue::get(VecType); 636 SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem); 637 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 638 Value *InsertElt; 639 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) { 640 InsertElt = UndefValue::get(EltTy); 641 ShuffleMask[Idx] = Idx + VectorWidth; 642 } else { 643 Value *NewPtr = 644 Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); 645 InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1), 646 "Load" + Twine(Idx)); 647 ShuffleMask[Idx] = Idx; 648 ++MemIndex; 649 } 650 VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx, 651 "Res" + Twine(Idx)); 652 } 653 VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask); 654 CI->replaceAllUsesWith(VResult); 655 CI->eraseFromParent(); 656 return; 657 } 658 659 // If the mask is not v1i1, use scalar bit test operations. 
This generates 660 // better results on X86 at least. 661 Value *SclrMask; 662 if (VectorWidth != 1) { 663 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 664 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 665 } 666 667 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 668 // Fill the "else" block, created in the previous iteration 669 // 670 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ] 671 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 672 // br i1 %mask_1, label %cond.load, label %else 673 // 674 675 Value *Predicate; 676 if (VectorWidth != 1) { 677 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 678 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 679 Builder.getIntN(VectorWidth, 0)); 680 } else { 681 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 682 } 683 684 // Create "cond" block 685 // 686 // %EltAddr = getelementptr i32* %1, i32 0 687 // %Elt = load i32* %EltAddr 688 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 689 // 690 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), 691 "cond.load"); 692 Builder.SetInsertPoint(InsertPt); 693 694 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1)); 695 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx); 696 697 // Move the pointer if there are more blocks to come. 698 Value *NewPtr; 699 if ((Idx + 1) != VectorWidth) 700 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); 701 702 // Create "else" block, fill it in the next iteration 703 BasicBlock *NewIfBlock = 704 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 705 Builder.SetInsertPoint(InsertPt); 706 Instruction *OldBr = IfBlock->getTerminator(); 707 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 708 OldBr->eraseFromParent(); 709 BasicBlock *PrevIfBlock = IfBlock; 710 IfBlock = NewIfBlock; 711 712 // Create the phi to join the new and previous value. 713 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 714 ResultPhi->addIncoming(NewVResult, CondBlock); 715 ResultPhi->addIncoming(VResult, PrevIfBlock); 716 VResult = ResultPhi; 717 718 // Add a PHI for the pointer if this isn't the last iteration. 719 if ((Idx + 1) != VectorWidth) { 720 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); 721 PtrPhi->addIncoming(NewPtr, CondBlock); 722 PtrPhi->addIncoming(Ptr, PrevIfBlock); 723 Ptr = PtrPhi; 724 } 725 } 726 727 CI->replaceAllUsesWith(VResult); 728 CI->eraseFromParent(); 729 730 ModifiedDT = true; 731 } 732 733 static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) { 734 Value *Src = CI->getArgOperand(0); 735 Value *Ptr = CI->getArgOperand(1); 736 Value *Mask = CI->getArgOperand(2); 737 738 auto *VecType = cast<FixedVectorType>(Src->getType()); 739 740 IRBuilder<> Builder(CI->getContext()); 741 Instruction *InsertPt = CI; 742 BasicBlock *IfBlock = CI->getParent(); 743 744 Builder.SetInsertPoint(InsertPt); 745 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 746 747 Type *EltTy = VecType->getElementType(); 748 749 unsigned VectorWidth = VecType->getNumElements(); 750 751 // Shorten the way if the mask is a vector of constants. 
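// A compressing store writes the enabled elements to consecutive memory
// locations, so with a constant mask the destination offset (MemIndex) of
// every enabled lane is known up front and no control flow is needed.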
752 if (isConstantIntVector(Mask)) { 753 unsigned MemIndex = 0; 754 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 755 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) 756 continue; 757 Value *OneElt = 758 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx)); 759 Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex); 760 Builder.CreateAlignedStore(OneElt, NewPtr, Align(1)); 761 ++MemIndex; 762 } 763 CI->eraseFromParent(); 764 return; 765 } 766 767 // If the mask is not v1i1, use scalar bit test operations. This generates 768 // better results on X86 at least. 769 Value *SclrMask; 770 if (VectorWidth != 1) { 771 Type *SclrMaskTy = Builder.getIntNTy(VectorWidth); 772 SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask"); 773 } 774 775 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 776 // Fill the "else" block, created in the previous iteration 777 // 778 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 779 // br i1 %mask_1, label %cond.store, label %else 780 // 781 Value *Predicate; 782 if (VectorWidth != 1) { 783 Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx)); 784 Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask), 785 Builder.getIntN(VectorWidth, 0)); 786 } else { 787 Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx)); 788 } 789 790 // Create "cond" block 791 // 792 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 793 // %EltAddr = getelementptr i32* %1, i32 0 794 // %store i32 %OneElt, i32* %EltAddr 795 // 796 BasicBlock *CondBlock = 797 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); 798 Builder.SetInsertPoint(InsertPt); 799 800 Value *OneElt = Builder.CreateExtractElement(Src, Idx); 801 Builder.CreateAlignedStore(OneElt, Ptr, Align(1)); 802 803 // Move the pointer if there are more blocks to come. 804 Value *NewPtr; 805 if ((Idx + 1) != VectorWidth) 806 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1); 807 808 // Create "else" block, fill it in the next iteration 809 BasicBlock *NewIfBlock = 810 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 811 Builder.SetInsertPoint(InsertPt); 812 Instruction *OldBr = IfBlock->getTerminator(); 813 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr); 814 OldBr->eraseFromParent(); 815 BasicBlock *PrevIfBlock = IfBlock; 816 IfBlock = NewIfBlock; 817 818 // Add a PHI for the pointer if this isn't the last iteration. 
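// The pointer is only advanced on the cond.store path, so this phi merges the
// incremented pointer from CondBlock with the unchanged one from the previous
// block.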
819 if ((Idx + 1) != VectorWidth) { 820 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else"); 821 PtrPhi->addIncoming(NewPtr, CondBlock); 822 PtrPhi->addIncoming(Ptr, PrevIfBlock); 823 Ptr = PtrPhi; 824 } 825 } 826 CI->eraseFromParent(); 827 828 ModifiedDT = true; 829 } 830 831 static bool runImpl(Function &F, const TargetTransformInfo &TTI) { 832 bool EverMadeChange = false; 833 bool MadeChange = true; 834 auto &DL = F.getParent()->getDataLayout(); 835 while (MadeChange) { 836 MadeChange = false; 837 for (Function::iterator I = F.begin(); I != F.end();) { 838 BasicBlock *BB = &*I++; 839 bool ModifiedDTOnIteration = false; 840 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL); 841 842 // Restart BB iteration if the dominator tree of the Function was changed 843 if (ModifiedDTOnIteration) 844 break; 845 } 846 847 EverMadeChange |= MadeChange; 848 } 849 return EverMadeChange; 850 } 851 852 bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) { 853 auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 854 return runImpl(F, TTI); 855 } 856 857 PreservedAnalyses 858 ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) { 859 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 860 if (!runImpl(F, TTI)) 861 return PreservedAnalyses::all(); 862 PreservedAnalyses PA; 863 PA.preserve<TargetIRAnalysis>(); 864 return PA; 865 } 866 867 static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT, 868 const TargetTransformInfo &TTI, 869 const DataLayout &DL) { 870 bool MadeChange = false; 871 872 BasicBlock::iterator CurInstIterator = BB.begin(); 873 while (CurInstIterator != BB.end()) { 874 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++)) 875 MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL); 876 if (ModifiedDT) 877 return true; 878 } 879 880 return MadeChange; 881 } 882 883 static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, 884 const TargetTransformInfo &TTI, 885 const DataLayout &DL) { 886 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI); 887 if (II) { 888 // The scalarization code below does not work for scalable vectors. 
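// (Scalable vectors have a runtime element count, so there is no fixed number
// of scalar memory operations to emit.)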
889 if (isa<ScalableVectorType>(II->getType()) || 890 any_of(II->arg_operands(), 891 [](Value *V) { return isa<ScalableVectorType>(V->getType()); })) 892 return false; 893 894 switch (II->getIntrinsicID()) { 895 default: 896 break; 897 case Intrinsic::masked_load: 898 // Scalarize unsupported vector masked load 899 if (TTI.isLegalMaskedLoad( 900 CI->getType(), 901 cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue())) 902 return false; 903 scalarizeMaskedLoad(CI, ModifiedDT); 904 return true; 905 case Intrinsic::masked_store: 906 if (TTI.isLegalMaskedStore( 907 CI->getArgOperand(0)->getType(), 908 cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue())) 909 return false; 910 scalarizeMaskedStore(CI, ModifiedDT); 911 return true; 912 case Intrinsic::masked_gather: { 913 unsigned AlignmentInt = 914 cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue(); 915 Type *LoadTy = CI->getType(); 916 Align Alignment = 917 DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy); 918 if (TTI.isLegalMaskedGather(LoadTy, Alignment)) 919 return false; 920 scalarizeMaskedGather(CI, ModifiedDT); 921 return true; 922 } 923 case Intrinsic::masked_scatter: { 924 unsigned AlignmentInt = 925 cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue(); 926 Type *StoreTy = CI->getArgOperand(0)->getType(); 927 Align Alignment = 928 DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy); 929 if (TTI.isLegalMaskedScatter(StoreTy, Alignment)) 930 return false; 931 scalarizeMaskedScatter(CI, ModifiedDT); 932 return true; 933 } 934 case Intrinsic::masked_expandload: 935 if (TTI.isLegalMaskedExpandLoad(CI->getType())) 936 return false; 937 scalarizeMaskedExpandLoad(CI, ModifiedDT); 938 return true; 939 case Intrinsic::masked_compressstore: 940 if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType())) 941 return false; 942 scalarizeMaskedCompressStore(CI, ModifiedDT); 943 return true; 944 } 945 } 946 947 return false; 948 } 949