//===-- X86PartialReduction.cpp -------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass looks for add instructions used by a horizontal reduction to see
// if we might be able to use pmaddwd or psadbw. Some cases of this require
// cross basic block knowledge and can't be done in SelectionDAG.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86TargetMachine.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/KnownBits.h"
#include <numeric>

using namespace llvm;

#define DEBUG_TYPE "x86-partial-reduction"

namespace {

class X86PartialReduction : public FunctionPass {
  const DataLayout *DL;
  const X86Subtarget *ST;

public:
  static char ID; // Pass identification, replacement for typeid.

  X86PartialReduction() : FunctionPass(ID) {}

  bool runOnFunction(Function &Fn) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
  }

  StringRef getPassName() const override {
    return "X86 Partial Reduction";
  }

private:
  bool tryMAddReplacement(Instruction *Op, bool ReduceInOneBB);
  bool trySADReplacement(Instruction *Op);
};

} // end anonymous namespace

FunctionPass *llvm::createX86PartialReductionPass() {
  return new X86PartialReduction();
}

char X86PartialReduction::ID = 0;

INITIALIZE_PASS(X86PartialReduction, DEBUG_TYPE,
                "X86 Partial Reduction", false, false)

// This function should be kept aligned with detectExtMul() in
// X86ISelLowering.cpp.
static bool matchVPDPBUSDPattern(const X86Subtarget *ST, BinaryOperator *Mul,
                                 const DataLayout *DL) {
  if (!ST->hasVNNI() && !ST->hasAVXVNNI())
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);

  if (isa<SExtInst>(LHS))
    std::swap(LHS, RHS);

  auto IsFreeTruncation = [&](Value *Op) {
    if (auto *Cast = dyn_cast<CastInst>(Op)) {
      if (Cast->getParent() == Mul->getParent() &&
          (Cast->getOpcode() == Instruction::SExt ||
           Cast->getOpcode() == Instruction::ZExt) &&
          Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 8)
        return true;
    }

    return isa<Constant>(Op);
  };

  // (dpbusd (zext a), (sext b)). Since the first operand should be an
  // unsigned value, we need to check that LHS is a zero-extended value. RHS
  // should be a signed value, so we just check its significant bits.
  if ((IsFreeTruncation(LHS) &&
       computeKnownBits(LHS, *DL).countMaxActiveBits() <= 8) &&
      (IsFreeTruncation(RHS) && ComputeMaxSignificantBits(RHS, *DL) <= 8))
    return true;

  return false;
}

bool X86PartialReduction::tryMAddReplacement(Instruction *Op,
                                             bool ReduceInOneBB) {
  if (!ST->hasSSE2())
    return false;

  // Need at least 8 elements.
  if (cast<FixedVectorType>(Op->getType())->getNumElements() < 8)
    return false;

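  // Everything below targets the pmaddwd idiom: pmaddwd multiplies packed
  // signed 16-bit elements and horizontally adds each adjacent pair of
  // 32-bit products, so a mul leaf is only worth rewriting when its inputs
  // can be narrowed to 16 bits.
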
  // Element type should be i32.
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  auto *Mul = dyn_cast<BinaryOperator>(Op);
  if (!Mul || Mul->getOpcode() != Instruction::Mul)
    return false;

  Value *LHS = Mul->getOperand(0);
  Value *RHS = Mul->getOperand(1);

  // If the target supports VNNI, leave it to ISel to combine the reduce
  // operation into a VNNI instruction.
  // TODO: we could also support transforming the reduce to a VNNI intrinsic
  // across basic blocks in this pass.
  if (ReduceInOneBB && matchVPDPBUSDPattern(ST, Mul, DL))
    return false;

  // LHS and RHS should only be used once, or if they are the same then only
  // used twice. Only check this when SSE4.1 is enabled and we have zext/sext
  // instructions, otherwise we use punpck to emulate zero extend in stages.
  // The truncates we need to do likely won't introduce new instructions in
  // that case.
  if (ST->hasSSE41()) {
    if (LHS == RHS) {
      if (!isa<Constant>(LHS) && !LHS->hasNUses(2))
        return false;
    } else {
      if (!isa<Constant>(LHS) && !LHS->hasOneUse())
        return false;
      if (!isa<Constant>(RHS) && !RHS->hasOneUse())
        return false;
    }
  }

  auto CanShrinkOp = [&](Value *Op) {
    auto IsFreeTruncation = [&](Value *Op) {
      if (auto *Cast = dyn_cast<CastInst>(Op)) {
        if (Cast->getParent() == Mul->getParent() &&
            (Cast->getOpcode() == Instruction::SExt ||
             Cast->getOpcode() == Instruction::ZExt) &&
            Cast->getOperand(0)->getType()->getScalarSizeInBits() <= 16)
          return true;
      }

      return isa<Constant>(Op);
    };

    // If the operation can be freely truncated and has enough sign bits we
    // can shrink.
    if (IsFreeTruncation(Op) &&
        ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
      return true;

    // SelectionDAG has limited support for truncating through an add or sub
    // if the inputs are freely truncatable.
    if (auto *BO = dyn_cast<BinaryOperator>(Op)) {
      if (BO->getParent() == Mul->getParent() &&
          IsFreeTruncation(BO->getOperand(0)) &&
          IsFreeTruncation(BO->getOperand(1)) &&
          ComputeNumSignBits(Op, *DL, 0, nullptr, Mul) > 16)
        return true;
    }

    return false;
  };

  // Both Ops need to be shrinkable.
  if (!CanShrinkOp(LHS) && !CanShrinkOp(RHS))
    return false;

  IRBuilder<> Builder(Mul);

  auto *MulTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = MulTy->getNumElements();

  // Extract the even and odd elements and add them together. This will be
  // pattern matched by SelectionDAG to pmaddwd. The result will be half the
  // original width.
  SmallVector<int, 16> EvenMask(NumElts / 2);
  SmallVector<int, 16> OddMask(NumElts / 2);
  for (int i = 0, e = NumElts / 2; i != e; ++i) {
    EvenMask[i] = i * 2;
    OddMask[i] = i * 2 + 1;
  }
  // Create a new mul so the replaceAllUsesWith below doesn't replace the
  // uses in the shuffles we're creating.
  Value *NewMul = Builder.CreateMul(Mul->getOperand(0), Mul->getOperand(1));
  Value *EvenElts = Builder.CreateShuffleVector(NewMul, NewMul, EvenMask);
  Value *OddElts = Builder.CreateShuffleVector(NewMul, NewMul, OddMask);
  Value *MAdd = Builder.CreateAdd(EvenElts, OddElts);

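  // As a sketch, for a v8i32 mul the IR built above looks like:
  //   %even = shufflevector <8 x i32> %mul, <8 x i32> %mul, <0, 2, 4, 6>
  //   %odd  = shufflevector <8 x i32> %mul, <8 x i32> %mul, <1, 3, 5, 7>
  //   %madd = add <4 x i32> %even, %odd
  // SelectionDAG matches this extract-even/extract-odd-and-add shape to
  // pmaddwd once the inputs are shrunk to 16 bits.
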
  // Concatenate zeroes to extend back to the original type.
  SmallVector<int, 32> ConcatMask(NumElts);
  std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
  Value *Zero = Constant::getNullValue(MAdd->getType());
  Value *Concat = Builder.CreateShuffleVector(MAdd, Zero, ConcatMask);

  Mul->replaceAllUsesWith(Concat);
  Mul->eraseFromParent();

  return true;
}

bool X86PartialReduction::trySADReplacement(Instruction *Op) {
  if (!ST->hasSSE2())
    return false;

  // TODO: There's nothing special about i32; any integer type above i16
  // should work just as well.
  if (!cast<VectorType>(Op->getType())->getElementType()->isIntegerTy(32))
    return false;

  // Operand should be a select.
  auto *SI = dyn_cast<SelectInst>(Op);
  if (!SI)
    return false;

  // Select needs to implement absolute value.
  Value *LHS, *RHS;
  auto SPR = matchSelectPattern(SI, LHS, RHS);
  if (SPR.Flavor != SPF_ABS)
    return false;

  // Need a subtract of two values.
  auto *Sub = dyn_cast<BinaryOperator>(LHS);
  if (!Sub || Sub->getOpcode() != Instruction::Sub)
    return false;

  // Look for zero extend from i8.
  auto getZeroExtendedVal = [](Value *Op) -> Value * {
    if (auto *ZExt = dyn_cast<ZExtInst>(Op))
      if (cast<VectorType>(ZExt->getOperand(0)->getType())
              ->getElementType()
              ->isIntegerTy(8))
        return ZExt->getOperand(0);

    return nullptr;
  };

  // Both operands of the subtract should be extends from vXi8.
  Value *Op0 = getZeroExtendedVal(Sub->getOperand(0));
  Value *Op1 = getZeroExtendedVal(Sub->getOperand(1));
  if (!Op0 || !Op1)
    return false;

  IRBuilder<> Builder(SI);

  auto *OpTy = cast<FixedVectorType>(Op->getType());
  unsigned NumElts = OpTy->getNumElements();

  unsigned IntrinsicNumElts;
  Intrinsic::ID IID;
  if (ST->hasBWI() && NumElts >= 64) {
    IID = Intrinsic::x86_avx512_psad_bw_512;
    IntrinsicNumElts = 64;
  } else if (ST->hasAVX2() && NumElts >= 32) {
    IID = Intrinsic::x86_avx2_psad_bw;
    IntrinsicNumElts = 32;
  } else {
    IID = Intrinsic::x86_sse2_psad_bw;
    IntrinsicNumElts = 16;
  }

  Function *PSADBWFn = Intrinsic::getDeclaration(SI->getModule(), IID);

  if (NumElts < 16) {
    // Pad inputs with zeroes.
    SmallVector<int, 32> ConcatMask(16);
    for (unsigned i = 0; i != NumElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = NumElts; i != 16; ++i)
      ConcatMask[i] = (i % NumElts) + NumElts;

    Value *Zero = Constant::getNullValue(Op0->getType());
    Op0 = Builder.CreateShuffleVector(Op0, Zero, ConcatMask);
    Op1 = Builder.CreateShuffleVector(Op1, Zero, ConcatMask);
    NumElts = 16;
  }

  // Intrinsics produce a vXi64 result that needs to be cast to vXi32.
  auto *I32Ty =
      FixedVectorType::get(Builder.getInt32Ty(), IntrinsicNumElts / 4);

  assert(NumElts % IntrinsicNumElts == 0 && "Unexpected number of elements!");
  unsigned NumSplits = NumElts / IntrinsicNumElts;

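  // Each psadbw sums the absolute differences of 8 pairs of unsigned bytes
  // into the low 16 bits of a 64-bit lane, so every call below produces a
  // vXi64 value that is immediately bitcast to vXi32.
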
  // First collect the pieces we need.
  SmallVector<Value *, 4> Ops(NumSplits);
  for (unsigned i = 0; i != NumSplits; ++i) {
    SmallVector<int, 64> ExtractMask(IntrinsicNumElts);
    std::iota(ExtractMask.begin(), ExtractMask.end(), i * IntrinsicNumElts);
    Value *ExtractOp0 = Builder.CreateShuffleVector(Op0, Op0, ExtractMask);
    Value *ExtractOp1 = Builder.CreateShuffleVector(Op1, Op1, ExtractMask);
    Ops[i] = Builder.CreateCall(PSADBWFn, {ExtractOp0, ExtractOp1});
    Ops[i] = Builder.CreateBitCast(Ops[i], I32Ty);
  }

  assert(isPowerOf2_32(NumSplits) && "Expected power of 2 splits");
  unsigned Stages = Log2_32(NumSplits);
  for (unsigned s = Stages; s > 0; --s) {
    unsigned NumConcatElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements() * 2;
    for (unsigned i = 0; i != 1U << (s - 1); ++i) {
      SmallVector<int, 64> ConcatMask(NumConcatElts);
      std::iota(ConcatMask.begin(), ConcatMask.end(), 0);
      Ops[i] =
          Builder.CreateShuffleVector(Ops[i * 2], Ops[i * 2 + 1], ConcatMask);
    }
  }

  // At this point the final value should be in Ops[0]. Now we need to adjust
  // it back to the original type.
  NumElts = OpTy->getNumElements();
  if (NumElts == 2) {
    // Extract down to 2 elements.
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Ops[0], ArrayRef<int>{0, 1});
  } else if (NumElts >= 8) {
    SmallVector<int, 32> ConcatMask(NumElts);
    unsigned SubElts =
        cast<FixedVectorType>(Ops[0]->getType())->getNumElements();
    for (unsigned i = 0; i != SubElts; ++i)
      ConcatMask[i] = i;
    for (unsigned i = SubElts; i != NumElts; ++i)
      ConcatMask[i] = (i % SubElts) + SubElts;

    Value *Zero = Constant::getNullValue(Ops[0]->getType());
    Ops[0] = Builder.CreateShuffleVector(Ops[0], Zero, ConcatMask);
  }

  SI->replaceAllUsesWith(Ops[0]);
  SI->eraseFromParent();

  return true;
}

// Walk backwards from the ExtractElementInst and determine if it is the end of
// a horizontal reduction. Return the input to the reduction if we find one.
static Value *matchAddReduction(const ExtractElementInst &EE,
                                bool &ReduceInOneBB) {
  ReduceInOneBB = true;
  // Make sure we're extracting index 0.
  auto *Index = dyn_cast<ConstantInt>(EE.getIndexOperand());
  if (!Index || !Index->isNullValue())
    return nullptr;

  const auto *BO = dyn_cast<BinaryOperator>(EE.getVectorOperand());
  if (!BO || BO->getOpcode() != Instruction::Add || !BO->hasOneUse())
    return nullptr;
  if (EE.getParent() != BO->getParent())
    ReduceInOneBB = false;

  unsigned NumElems = cast<FixedVectorType>(BO->getType())->getNumElements();
  // Ensure the reduction size is a power of 2.
  if (!isPowerOf2_32(NumElems))
    return nullptr;

  const Value *Op = BO;
  unsigned Stages = Log2_32(NumElems);
  for (unsigned i = 0; i != Stages; ++i) {
    const auto *BO = dyn_cast<BinaryOperator>(Op);
    if (!BO || BO->getOpcode() != Instruction::Add)
      return nullptr;
    if (EE.getParent() != BO->getParent())
      ReduceInOneBB = false;

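    // Walking from the final scalar add back up the pyramid, the i-th stage
    // should add the vector to a copy of itself with the upper 1 << i
    // elements folded down, e.g. for i == 0 (a sketch, mask abbreviated):
    //   %s = shufflevector <8 x i32> %a, <8 x i32> poison, <1, undef, ...>
    //   %r = add <8 x i32> %a, %s
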
    // If this isn't the first add, then it should only have 2 uses: the
    // shuffle and another add which we checked in the previous iteration.
    if (i != 0 && !BO->hasNUses(2))
      return nullptr;

    Value *LHS = BO->getOperand(0);
    Value *RHS = BO->getOperand(1);

    auto *Shuffle = dyn_cast<ShuffleVectorInst>(LHS);
    if (Shuffle) {
      Op = RHS;
    } else {
      Shuffle = dyn_cast<ShuffleVectorInst>(RHS);
      Op = LHS;
    }

    // The first operand of the shuffle should be the same as the other operand
    // of the bin op.
    if (!Shuffle || Shuffle->getOperand(0) != Op)
      return nullptr;

    // Verify the shuffle has the expected (at this stage of the pyramid) mask.
    unsigned MaskEnd = 1 << i;
    for (unsigned Index = 0; Index < MaskEnd; ++Index)
      if (Shuffle->getMaskValue(Index) != (int)(MaskEnd + Index))
        return nullptr;
  }

  return const_cast<Value *>(Op);
}

// See if this BO is reachable from this Phi by walking forward through single
// use BinaryOperators with the same opcode. If we get back to BO then we know
// we've found a loop and it is safe to step through this Add to find more
// leaves.
static bool isReachableFromPHI(PHINode *Phi, BinaryOperator *BO) {
  // The PHI itself should only have one use.
  if (!Phi->hasOneUse())
    return false;

  Instruction *U = cast<Instruction>(*Phi->user_begin());
  if (U == BO)
    return true;

  while (U->hasOneUse() && U->getOpcode() == BO->getOpcode())
    U = cast<Instruction>(*U->user_begin());

  return U == BO;
}

// Collect all the leaves of the tree of adds that feeds into the horizontal
// reduction. Root is the Value that is used by the horizontal reduction.
// We look through single use phis, single use adds, or adds that are used by
// a phi that forms a loop with the add.
static void collectLeaves(Value *Root, SmallVectorImpl<Instruction *> &Leaves) {
  SmallPtrSet<Value *, 8> Visited;
  SmallVector<Value *, 8> Worklist;
  Worklist.push_back(Root);

  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    if (auto *PN = dyn_cast<PHINode>(V)) {
      // The PHI node should have a single use, unless it is the root node, in
      // which case it has 2 uses.
      if (!PN->hasNUses(PN == Root ? 2 : 1))
        break;

      // Push incoming values to the worklist.
      append_range(Worklist, PN->incoming_values());

      continue;
    }

    if (auto *BO = dyn_cast<BinaryOperator>(V)) {
      if (BO->getOpcode() == Instruction::Add) {
        // Simple case. Single use, just push its operands to the worklist.
        if (BO->hasNUses(BO == Root ? 2 : 1)) {
          append_range(Worklist, BO->operands());
          continue;
        }

        // If there is an additional use, make sure it is an unvisited phi that
        // gets us back to this node.
        if (BO->hasNUses(BO == Root ? 3 : 2)) {
          PHINode *PN = nullptr;
          for (auto *U : BO->users())
            if (auto *P = dyn_cast<PHINode>(U))
              if (!Visited.count(P))
                PN = P;

          // If we didn't find a 2-input PHI then this isn't a case we can
          // handle.
          if (!PN || PN->getNumIncomingValues() != 2)
            continue;

          // Walk forward from this phi to see if it reaches back to this add.
          if (!isReachableFromPHI(PN, BO))
            continue;

          // The phi forms a loop with this Add, push its operands.
          append_range(Worklist, BO->operands());
        }
      }
    }

    // Not an add or phi, make it a leaf.
    if (auto *I = dyn_cast<Instruction>(V)) {
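      // A leaf must have exactly one use, or two when it is the root, since
      // the root is also used by the shuffle of the reduction's first stage.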
      if (!V->hasNUses(I == Root ? 2 : 1))
        continue;

      // Add this as a leaf.
      Leaves.push_back(I);
    }
  }
}

bool X86PartialReduction::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
  if (!TPC)
    return false;

  auto &TM = TPC->getTM<X86TargetMachine>();
  ST = TM.getSubtargetImpl(F);

  DL = &F.getParent()->getDataLayout();

  bool MadeChange = false;
  for (auto &BB : F) {
    for (auto &I : BB) {
      auto *EE = dyn_cast<ExtractElementInst>(&I);
      if (!EE)
        continue;

      bool ReduceInOneBB;
      // First find a reduction tree.
      // FIXME: Do we need to handle other opcodes than Add?
      Value *Root = matchAddReduction(*EE, ReduceInOneBB);
      if (!Root)
        continue;

      SmallVector<Instruction *, 8> Leaves;
      collectLeaves(Root, Leaves);

      for (Instruction *I : Leaves) {
        if (tryMAddReplacement(I, ReduceInOneBB)) {
          MadeChange = true;
          continue;
        }

        // Don't do SAD matching on the root node. SelectionDAG already
        // has support for that and currently generates better code.
        if (I != Root && trySADReplacement(I))
          MadeChange = true;
      }
    }
  }

  return MadeChange;
}