//===-- X86FixupVectorConstants.cpp - optimize constant generation -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file examines all full size vector constant pool loads and attempts to
// replace them with smaller constant pool entries, including:
// * Converting AVX512 memory-fold instructions to their broadcast-fold form
// * TODO: Broadcasting of full width loads.
// * TODO: Sign/Zero extension of full width loads.
//
//===----------------------------------------------------------------------===//

#include "X86.h"
#include "X86InstrFoldTables.h"
#include "X86InstrInfo.h"
#include "X86Subtarget.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineConstantPool.h"

using namespace llvm;

#define DEBUG_TYPE "x86-fixup-vector-constants"

STATISTIC(NumInstChanges, "Number of instruction changes");

namespace {
class X86FixupVectorConstantsPass : public MachineFunctionPass {
public:
  static char ID;

  X86FixupVectorConstantsPass() : MachineFunctionPass(ID) {}

  StringRef getPassName() const override {
    return "X86 Fixup Vector Constants";
  }

  bool runOnMachineFunction(MachineFunction &MF) override;
  bool processInstruction(MachineFunction &MF, MachineBasicBlock &MBB,
                          MachineInstr &MI);

  // This pass runs after regalloc and doesn't support VReg operands.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().set(
        MachineFunctionProperties::Property::NoVRegs);
  }

private:
  const X86InstrInfo *TII = nullptr;
  const X86Subtarget *ST = nullptr;
  const MCSchedModel *SM = nullptr;
};
} // end anonymous namespace

char X86FixupVectorConstantsPass::ID = 0;

INITIALIZE_PASS(X86FixupVectorConstantsPass, DEBUG_TYPE, DEBUG_TYPE, false,
                false)

FunctionPass *llvm::createX86FixupVectorConstants() {
  return new X86FixupVectorConstantsPass();
}

static const Constant *getConstantFromPool(const MachineInstr &MI,
                                           const MachineOperand &Op) {
  if (!Op.isCPI() || Op.getOffset() != 0)
    return nullptr;

  ArrayRef<MachineConstantPoolEntry> Constants =
      MI.getParent()->getParent()->getConstantPool()->getConstants();
  const MachineConstantPoolEntry &ConstantEntry = Constants[Op.getIndex()];

  // Bail if this is a machine constant pool entry; we won't be able to dig out
  // anything useful.
  if (ConstantEntry.isMachineConstantPoolEntry())
    return nullptr;

  return ConstantEntry.Val.ConstVal;
}
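
// For example, a 128-bit constant pool entry such as
// <4 x i32> <i32 42, i32 42, i32 42, i32 42> carries only 32 bits of unique
// data, so its 16-byte load can be rewritten as a 4-byte broadcast load (e.g.
// VBROADCASTSSrm) of a single i32 entry. The helpers below flatten constants
// into raw bit patterns so such repetition can be detected.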

// Attempt to extract the full width of bits data from the constant.
static std::optional<APInt> extractConstantBits(const Constant *C) {
  unsigned NumBits = C->getType()->getPrimitiveSizeInBits();

  if (auto *CInt = dyn_cast<ConstantInt>(C))
    return CInt->getValue();

  if (auto *CFP = dyn_cast<ConstantFP>(C))
    return CFP->getValue().bitcastToAPInt();

  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
      if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
        assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
        return APInt::getSplat(NumBits, *Bits);
      }
    }
  }

  if (auto *CDS = dyn_cast<ConstantDataSequential>(C)) {
    bool IsInteger = CDS->getElementType()->isIntegerTy();
    bool IsFloat = CDS->getElementType()->isHalfTy() ||
                   CDS->getElementType()->isBFloatTy() ||
                   CDS->getElementType()->isFloatTy() ||
                   CDS->getElementType()->isDoubleTy();
    if (IsInteger || IsFloat) {
      APInt Bits = APInt::getZero(NumBits);
      unsigned EltBits = CDS->getElementType()->getPrimitiveSizeInBits();
      for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) {
        if (IsInteger)
          Bits.insertBits(CDS->getElementAsAPInt(I), I * EltBits);
        else
          Bits.insertBits(CDS->getElementAsAPFloat(I).bitcastToAPInt(),
                          I * EltBits);
      }
      return Bits;
    }
  }

  return std::nullopt;
}

// Attempt to compute the splat width of bits data by normalizing the splat to
// remove undefs.
static std::optional<APInt> getSplatableConstant(const Constant *C,
                                                 unsigned SplatBitWidth) {
  const Type *Ty = C->getType();
  assert((Ty->getPrimitiveSizeInBits() % SplatBitWidth) == 0 &&
         "Illegal splat width");

  if (std::optional<APInt> Bits = extractConstantBits(C))
    if (Bits->isSplat(SplatBitWidth))
      return Bits->trunc(SplatBitWidth);

  // Detect general splats with undefs.
  // TODO: Do we need to handle NumEltsBits > SplatBitWidth splitting?
  if (auto *CV = dyn_cast<ConstantVector>(C)) {
    unsigned NumOps = CV->getNumOperands();
    unsigned NumEltsBits = Ty->getScalarSizeInBits();
    unsigned NumScaleOps = SplatBitWidth / NumEltsBits;
    if ((SplatBitWidth % NumEltsBits) == 0) {
      // Collect the elements and ensure that, within the repeated splat
      // sequence, they either match or are undef.
      SmallVector<Constant *, 16> Sequence(NumScaleOps, nullptr);
      for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
        if (Constant *Elt = CV->getAggregateElement(Idx)) {
          if (isa<UndefValue>(Elt))
            continue;
          unsigned SplatIdx = Idx % NumScaleOps;
          if (!Sequence[SplatIdx] || Sequence[SplatIdx] == Elt) {
            Sequence[SplatIdx] = Elt;
            continue;
          }
        }
        return std::nullopt;
      }
      // Extract the constant bits forming the splat and insert them into the
      // bits data, leaving undef as zero.
      APInt SplatBits = APInt::getZero(SplatBitWidth);
      for (unsigned I = 0; I != NumScaleOps; ++I) {
        if (!Sequence[I])
          continue;
        if (std::optional<APInt> Bits = extractConstantBits(Sequence[I])) {
          SplatBits.insertBits(*Bits, I * Bits->getBitWidth());
          continue;
        }
        return std::nullopt;
      }
      return SplatBits;
    }
  }

  return std::nullopt;
}
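
// A sketch of the undef handling above: with SplatBitWidth == 64, a constant
// like <4 x i32> <i32 1, i32 2, i32 undef, i32 2> is accepted as the 64-bit
// splat {1, 2}; the undef lane is ignored when matching the repeated sequence,
// and any slot that is never defined contributes zero bits to the result.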

// Attempt to rebuild a normalized splat vector constant of the requested splat
// width, built up of potentially smaller scalar values.
// NOTE: We don't always bother converting to scalars if the vector length
// is 1.
static Constant *rebuildSplatableConstant(const Constant *C,
                                          unsigned SplatBitWidth) {
  std::optional<APInt> Splat = getSplatableConstant(C, SplatBitWidth);
  if (!Splat)
    return nullptr;

  // Determine scalar size to use for the constant splat vector, clamping as we
  // might have found a splat smaller than the original constant data.
  const Type *OriginalType = C->getType();
  Type *SclTy = OriginalType->getScalarType();
  unsigned NumSclBits = SclTy->getPrimitiveSizeInBits();
  NumSclBits = std::min<unsigned>(NumSclBits, SplatBitWidth);

  if (NumSclBits == 8) {
    SmallVector<uint8_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 8)
      RawBits.push_back(Splat->extractBits(8, I).getZExtValue());
    return ConstantDataVector::get(OriginalType->getContext(), RawBits);
  }

  if (NumSclBits == 16) {
    SmallVector<uint16_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 16)
      RawBits.push_back(Splat->extractBits(16, I).getZExtValue());
    if (SclTy->is16bitFPTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(OriginalType->getContext(), RawBits);
  }

  if (NumSclBits == 32) {
    SmallVector<uint32_t> RawBits;
    for (unsigned I = 0; I != SplatBitWidth; I += 32)
      RawBits.push_back(Splat->extractBits(32, I).getZExtValue());
    if (SclTy->isFloatTy())
      return ConstantDataVector::getFP(SclTy, RawBits);
    return ConstantDataVector::get(OriginalType->getContext(), RawBits);
  }

  // Fallback to i64 / double.
  SmallVector<uint64_t> RawBits;
  for (unsigned I = 0; I != SplatBitWidth; I += 64)
    RawBits.push_back(Splat->extractBits(64, I).getZExtValue());
  if (SclTy->isDoubleTy())
    return ConstantDataVector::getFP(SclTy, RawBits);
  return ConstantDataVector::get(OriginalType->getContext(), RawBits);
}

bool X86FixupVectorConstantsPass::processInstruction(MachineFunction &MF,
                                                     MachineBasicBlock &MBB,
                                                     MachineInstr &MI) {
  unsigned Opc = MI.getOpcode();
  MachineConstantPool *CP = MI.getParent()->getParent()->getConstantPool();
  bool HasDQI = ST->hasDQI();
  bool HasBWI = ST->hasBWI();

  auto ConvertToBroadcast = [&](unsigned OpBcst256, unsigned OpBcst128,
                                unsigned OpBcst64, unsigned OpBcst32,
                                unsigned OpBcst16, unsigned OpBcst8,
                                unsigned OperandNo) {
    assert(MI.getNumOperands() >= (OperandNo + X86::AddrNumOperands) &&
           "Unexpected number of operands!");

    MachineOperand &CstOp = MI.getOperand(OperandNo + X86::AddrDisp);
    if (auto *C = getConstantFromPool(MI, CstOp)) {
      // Attempt to detect a suitable splat from increasing splat widths.
      std::pair<unsigned, unsigned> Broadcasts[] = {
          {8, OpBcst8},   {16, OpBcst16},   {32, OpBcst32},
          {64, OpBcst64}, {128, OpBcst128}, {256, OpBcst256},
      };
      for (auto [BitWidth, OpBcst] : Broadcasts) {
        if (OpBcst) {
          // Construct a suitable splat constant and adjust the MI to
          // use the new constant pool entry.
          if (Constant *NewCst = rebuildSplatableConstant(C, BitWidth)) {
            unsigned NewCPI =
                CP->getConstantPoolIndex(NewCst, Align(BitWidth / 8));
            MI.setDesc(TII->get(OpBcst));
            CstOp.setIndex(NewCPI);
            return true;
          }
        }
      }
    }
    return false;
  };
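
  // The widths are tried from narrowest to widest so the smallest possible
  // constant pool entry wins; a zero opcode means no broadcast form of that
  // width is available for the current instruction/subtarget and is skipped.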

  // Attempt to convert full width vector loads into broadcast loads.
  switch (Opc) {
  /* FP Loads */
  case X86::MOVAPDrm:
  case X86::MOVAPSrm:
  case X86::MOVUPDrm:
  case X86::MOVUPSrm:
    // TODO: SSE3 MOVDDUP Handling
    return false;
  case X86::VMOVAPDrm:
  case X86::VMOVAPSrm:
  case X86::VMOVUPDrm:
  case X86::VMOVUPSrm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
                              1);
  case X86::VMOVAPDYrm:
  case X86::VMOVAPSYrm:
  case X86::VMOVUPDYrm:
  case X86::VMOVUPSYrm:
    return ConvertToBroadcast(0, X86::VBROADCASTF128, X86::VBROADCASTSDYrm,
                              X86::VBROADCASTSSYrm, 0, 0, 1);
  case X86::VMOVAPDZ128rm:
  case X86::VMOVAPSZ128rm:
  case X86::VMOVUPDZ128rm:
  case X86::VMOVUPSZ128rm:
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPZ128rm,
                              X86::VBROADCASTSSZ128rm, 0, 0, 1);
  case X86::VMOVAPDZ256rm:
  case X86::VMOVAPSZ256rm:
  case X86::VMOVUPDZ256rm:
  case X86::VMOVUPSZ256rm:
    return ConvertToBroadcast(
        0, HasDQI ? X86::VBROADCASTF64X2Z128rm : X86::VBROADCASTF32X4Z256rm,
        X86::VBROADCASTSDZ256rm, X86::VBROADCASTSSZ256rm, 0, 0, 1);
  case X86::VMOVAPDZrm:
  case X86::VMOVAPSZrm:
  case X86::VMOVUPDZrm:
  case X86::VMOVUPSZrm:
    return ConvertToBroadcast(
        HasDQI ? X86::VBROADCASTF32X8rm : X86::VBROADCASTF64X4rm,
        HasDQI ? X86::VBROADCASTF64X2rm : X86::VBROADCASTF32X4rm,
        X86::VBROADCASTSDZrm, X86::VBROADCASTSSZrm, 0, 0, 1);
  /* Integer Loads */
  case X86::VMOVDQArm:
  case X86::VMOVDQUrm:
    if (ST->hasAVX2())
      return ConvertToBroadcast(0, 0, X86::VPBROADCASTQrm, X86::VPBROADCASTDrm,
                                X86::VPBROADCASTWrm, X86::VPBROADCASTBrm, 1);
    return ConvertToBroadcast(0, 0, X86::VMOVDDUPrm, X86::VBROADCASTSSrm, 0, 0,
                              1);
  case X86::VMOVDQAYrm:
  case X86::VMOVDQUYrm:
    if (ST->hasAVX2())
      return ConvertToBroadcast(0, X86::VBROADCASTI128, X86::VPBROADCASTQYrm,
                                X86::VPBROADCASTDYrm, X86::VPBROADCASTWYrm,
                                X86::VPBROADCASTBYrm, 1);
    return ConvertToBroadcast(0, X86::VBROADCASTF128, X86::VBROADCASTSDYrm,
                              X86::VBROADCASTSSYrm, 0, 0, 1);
  case X86::VMOVDQA32Z128rm:
  case X86::VMOVDQA64Z128rm:
  case X86::VMOVDQU32Z128rm:
  case X86::VMOVDQU64Z128rm:
    return ConvertToBroadcast(0, 0, X86::VPBROADCASTQZ128rm,
                              X86::VPBROADCASTDZ128rm,
                              HasBWI ? X86::VPBROADCASTWZ128rm : 0,
                              HasBWI ? X86::VPBROADCASTBZ128rm : 0, 1);
  case X86::VMOVDQA32Z256rm:
  case X86::VMOVDQA64Z256rm:
  case X86::VMOVDQU32Z256rm:
  case X86::VMOVDQU64Z256rm:
    return ConvertToBroadcast(
        0, HasDQI ? X86::VBROADCASTI64X2Z128rm : X86::VBROADCASTI32X4Z256rm,
        X86::VPBROADCASTQZ256rm, X86::VPBROADCASTDZ256rm,
        HasBWI ? X86::VPBROADCASTWZ256rm : 0,
        HasBWI ? X86::VPBROADCASTBZ256rm : 0, 1);
  case X86::VMOVDQA32Zrm:
  case X86::VMOVDQA64Zrm:
  case X86::VMOVDQU32Zrm:
  case X86::VMOVDQU64Zrm:
    return ConvertToBroadcast(
        HasDQI ? X86::VBROADCASTI32X8rm : X86::VBROADCASTI64X4rm,
        HasDQI ? X86::VBROADCASTI64X2rm : X86::VBROADCASTI32X4rm,
        X86::VPBROADCASTQZrm, X86::VPBROADCASTDZrm,
        HasBWI ? X86::VPBROADCASTWZrm : 0, HasBWI ? X86::VPBROADCASTBZrm : 0,
        1);
  }
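
  // For instructions other than plain loads, e.g. an EVEX arithmetic op that
  // folds a full width constant load, the fold tables below can map to the
  // embedded-broadcast form so only a single 32/64-bit element needs to be
  // kept in the pool; TB_INDEX_MASK recovers which operand the memory
  // reference starts at.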

  // Attempt to find an AVX512 mapping from a full width memory-fold
  // instruction to a broadcast-fold instruction variant.
  if ((MI.getDesc().TSFlags & X86II::EncodingMask) == X86II::EVEX) {
    unsigned OpBcst32 = 0, OpBcst64 = 0;
    unsigned OpNoBcst32 = 0, OpNoBcst64 = 0;
    if (const X86MemoryFoldTableEntry *Mem2Bcst =
            llvm::lookupBroadcastFoldTable(Opc, 32)) {
      OpBcst32 = Mem2Bcst->DstOp;
      OpNoBcst32 = Mem2Bcst->Flags & TB_INDEX_MASK;
    }
    if (const X86MemoryFoldTableEntry *Mem2Bcst =
            llvm::lookupBroadcastFoldTable(Opc, 64)) {
      OpBcst64 = Mem2Bcst->DstOp;
      OpNoBcst64 = Mem2Bcst->Flags & TB_INDEX_MASK;
    }
    assert(((OpBcst32 == 0) || (OpBcst64 == 0) || (OpNoBcst32 == OpNoBcst64)) &&
           "OperandNo mismatch");

    if (OpBcst32 || OpBcst64) {
      unsigned OpNo = OpBcst32 == 0 ? OpNoBcst64 : OpNoBcst32;
      return ConvertToBroadcast(0, 0, OpBcst64, OpBcst32, 0, 0, OpNo);
    }
  }

  return false;
}

bool X86FixupVectorConstantsPass::runOnMachineFunction(MachineFunction &MF) {
  LLVM_DEBUG(dbgs() << "Start X86FixupVectorConstants\n";);
  bool Changed = false;
  ST = &MF.getSubtarget<X86Subtarget>();
  TII = ST->getInstrInfo();
  SM = &ST->getSchedModel();

  for (MachineBasicBlock &MBB : MF) {
    for (MachineInstr &MI : MBB) {
      if (processInstruction(MF, MBB, MI)) {
        ++NumInstChanges;
        Changed = true;
      }
    }
  }
  LLVM_DEBUG(dbgs() << "End X86FixupVectorConstants\n";);
  return Changed;
}