1 //===- CombinerHelperVectorOps.cpp-----------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements CombinerHelper for G_EXTRACT_VECTOR_ELT, 10 // G_INSERT_VECTOR_ELT, and G_VSCALE 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 14 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 15 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 16 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 17 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/GlobalISel/Utils.h" 20 #include "llvm/CodeGen/LowLevelTypeUtils.h" 21 #include "llvm/CodeGen/MachineOperand.h" 22 #include "llvm/CodeGen/MachineRegisterInfo.h" 23 #include "llvm/CodeGen/TargetLowering.h" 24 #include "llvm/CodeGen/TargetOpcodes.h" 25 #include "llvm/Support/Casting.h" 26 #include <optional> 27 28 #define DEBUG_TYPE "gi-combiner" 29 30 using namespace llvm; 31 using namespace MIPatternMatch; 32 33 bool CombinerHelper::matchExtractVectorElement(MachineInstr &MI, 34 BuildFnTy &MatchInfo) { 35 GExtractVectorElement *Extract = cast<GExtractVectorElement>(&MI); 36 37 Register Dst = Extract->getReg(0); 38 Register Vector = Extract->getVectorReg(); 39 Register Index = Extract->getIndexReg(); 40 LLT DstTy = MRI.getType(Dst); 41 LLT VectorTy = MRI.getType(Vector); 42 43 // The vector register can be def'd by various ops that have vector as its 44 // type. They can all be used for constant folding, scalarizing, 45 // canonicalization, or combining based on symmetry. 46 // 47 // vector like ops 48 // * build vector 49 // * build vector trunc 50 // * shuffle vector 51 // * splat vector 52 // * concat vectors 53 // * insert/extract vector element 54 // * insert/extract subvector 55 // * vector loads 56 // * scalable vector loads 57 // 58 // compute like ops 59 // * binary ops 60 // * unary ops 61 // * exts and truncs 62 // * casts 63 // * fneg 64 // * select 65 // * phis 66 // * cmps 67 // * freeze 68 // * bitcast 69 // * undef 70 71 // We try to get the value of the Index register. 72 std::optional<ValueAndVReg> MaybeIndex = 73 getIConstantVRegValWithLookThrough(Index, MRI); 74 std::optional<APInt> IndexC = std::nullopt; 75 76 if (MaybeIndex) 77 IndexC = MaybeIndex->Value; 78 79 // Fold extractVectorElement(Vector, TOOLARGE) -> undef 80 if (IndexC && VectorTy.isFixedVector() && 81 IndexC->uge(VectorTy.getNumElements()) && 82 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 83 // For fixed-length vectors, it's invalid to extract out-of-range elements. 84 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 85 return true; 86 } 87 88 return false; 89 } 90 91 bool CombinerHelper::matchExtractVectorElementWithDifferentIndices( 92 const MachineOperand &MO, BuildFnTy &MatchInfo) { 93 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 94 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root); 95 96 // 97 // %idx1:_(s64) = G_CONSTANT i64 1 98 // %idx2:_(s64) = G_CONSTANT i64 2 99 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), 100 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %insert(<2 101 // x s32>), %idx1(s64) 102 // 103 // --> 104 // 105 // %insert:_(<2 x s32>) = G_INSERT_VECTOR_ELT_ELT %bv(<2 x s32>), 106 // %value(s32), %idx2(s64) %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x 107 // s32>), %idx1(s64) 108 // 109 // 110 111 Register Index = Extract->getIndexReg(); 112 113 // We try to get the value of the Index register. 114 std::optional<ValueAndVReg> MaybeIndex = 115 getIConstantVRegValWithLookThrough(Index, MRI); 116 std::optional<APInt> IndexC = std::nullopt; 117 118 if (!MaybeIndex) 119 return false; 120 else 121 IndexC = MaybeIndex->Value; 122 123 Register Vector = Extract->getVectorReg(); 124 125 GInsertVectorElement *Insert = 126 getOpcodeDef<GInsertVectorElement>(Vector, MRI); 127 if (!Insert) 128 return false; 129 130 Register Dst = Extract->getReg(0); 131 132 std::optional<ValueAndVReg> MaybeInsertIndex = 133 getIConstantVRegValWithLookThrough(Insert->getIndexReg(), MRI); 134 135 if (MaybeInsertIndex && MaybeInsertIndex->Value != *IndexC) { 136 // There is no one-use check. We have to keep the insert. When both Index 137 // registers are constants and not equal, we can look into the Vector 138 // register of the insert. 139 MatchInfo = [=](MachineIRBuilder &B) { 140 B.buildExtractVectorElement(Dst, Insert->getVectorReg(), Index); 141 }; 142 return true; 143 } 144 145 return false; 146 } 147 148 bool CombinerHelper::matchExtractVectorElementWithBuildVector( 149 const MachineOperand &MO, BuildFnTy &MatchInfo) { 150 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 151 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root); 152 153 // 154 // %zero:_(s64) = G_CONSTANT i64 0 155 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) 156 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) 157 // 158 // --> 159 // 160 // %extract:_(32) = COPY %arg1(s32) 161 // 162 // 163 // 164 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) 165 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 166 // 167 // --> 168 // 169 // %bv:_(<2 x s32>) = G_BUILD_VECTOR %arg1(s32), %arg2(s32) 170 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 171 // 172 173 Register Vector = Extract->getVectorReg(); 174 175 // We expect a buildVector on the Vector register. 176 GBuildVector *Build = getOpcodeDef<GBuildVector>(Vector, MRI); 177 if (!Build) 178 return false; 179 180 LLT VectorTy = MRI.getType(Vector); 181 182 // There is a one-use check. There are more combines on build vectors. 183 EVT Ty(getMVTForLLT(VectorTy)); 184 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) || 185 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty)) 186 return false; 187 188 Register Index = Extract->getIndexReg(); 189 190 // If the Index is constant, then we can extract the element from the given 191 // offset. 192 std::optional<ValueAndVReg> MaybeIndex = 193 getIConstantVRegValWithLookThrough(Index, MRI); 194 if (!MaybeIndex) 195 return false; 196 197 // We now know that there is a buildVector def'd on the Vector register and 198 // the index is const. The combine will succeed. 199 200 Register Dst = Extract->getReg(0); 201 202 MatchInfo = [=](MachineIRBuilder &B) { 203 B.buildCopy(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue())); 204 }; 205 206 return true; 207 } 208 209 bool CombinerHelper::matchExtractVectorElementWithBuildVectorTrunc( 210 const MachineOperand &MO, BuildFnTy &MatchInfo) { 211 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 212 GExtractVectorElement *Extract = cast<GExtractVectorElement>(Root); 213 214 // 215 // %zero:_(s64) = G_CONSTANT i64 0 216 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 217 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %zero(s64) 218 // 219 // --> 220 // 221 // %extract:_(32) = G_TRUNC %arg1(s64) 222 // 223 // 224 // 225 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 226 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 227 // 228 // --> 229 // 230 // %bv:_(<2 x s32>) = G_BUILD_VECTOR_TRUNC %arg1(s64), %arg2(s64) 231 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %bv(<2 x s32>), %opaque(s64) 232 // 233 234 Register Vector = Extract->getVectorReg(); 235 236 // We expect a buildVectorTrunc on the Vector register. 237 GBuildVectorTrunc *Build = getOpcodeDef<GBuildVectorTrunc>(Vector, MRI); 238 if (!Build) 239 return false; 240 241 LLT VectorTy = MRI.getType(Vector); 242 243 // There is a one-use check. There are more combines on build vectors. 244 EVT Ty(getMVTForLLT(VectorTy)); 245 if (!MRI.hasOneNonDBGUse(Build->getReg(0)) || 246 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty)) 247 return false; 248 249 Register Index = Extract->getIndexReg(); 250 251 // If the Index is constant, then we can extract the element from the given 252 // offset. 253 std::optional<ValueAndVReg> MaybeIndex = 254 getIConstantVRegValWithLookThrough(Index, MRI); 255 if (!MaybeIndex) 256 return false; 257 258 // We now know that there is a buildVectorTrunc def'd on the Vector register 259 // and the index is const. The combine will succeed. 260 261 Register Dst = Extract->getReg(0); 262 LLT DstTy = MRI.getType(Dst); 263 LLT SrcTy = MRI.getType(Build->getSourceReg(0)); 264 265 // For buildVectorTrunc, the inputs are truncated. 266 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) 267 return false; 268 269 MatchInfo = [=](MachineIRBuilder &B) { 270 B.buildTrunc(Dst, Build->getSourceReg(MaybeIndex->Value.getZExtValue())); 271 }; 272 273 return true; 274 } 275 276 bool CombinerHelper::matchExtractVectorElementWithShuffleVector( 277 const MachineOperand &MO, BuildFnTy &MatchInfo) { 278 GExtractVectorElement *Extract = 279 cast<GExtractVectorElement>(getDefIgnoringCopies(MO.getReg(), MRI)); 280 281 // 282 // %zero:_(s64) = G_CONSTANT i64 0 283 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 284 // shufflemask(0, 0, 0, 0) 285 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %zero(s64) 286 // 287 // --> 288 // 289 // %zero1:_(s64) = G_CONSTANT i64 0 290 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %arg1(<4 x s32>), %zero1(s64) 291 // 292 // 293 // 294 // 295 // %three:_(s64) = G_CONSTANT i64 3 296 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 297 // shufflemask(0, 0, 0, -1) 298 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %three(s64) 299 // 300 // --> 301 // 302 // %extract:_(s32) = G_IMPLICIT_DEF 303 // 304 // 305 // 306 // 307 // 308 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 309 // shufflemask(0, 0, 0, -1) 310 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64) 311 // 312 // --> 313 // 314 // %sv:_(<4 x s32>) = G_SHUFFLE_SHUFFLE %arg1(<4 x s32>), %arg2(<4 x s32>), 315 // shufflemask(0, 0, 0, -1) 316 // %extract:_(s32) = G_EXTRACT_VECTOR_ELT %sv(<4 x s32>), %opaque(s64) 317 // 318 319 // We try to get the value of the Index register. 320 std::optional<ValueAndVReg> MaybeIndex = 321 getIConstantVRegValWithLookThrough(Extract->getIndexReg(), MRI); 322 if (!MaybeIndex) 323 return false; 324 325 GShuffleVector *Shuffle = 326 cast<GShuffleVector>(getDefIgnoringCopies(Extract->getVectorReg(), MRI)); 327 328 ArrayRef<int> Mask = Shuffle->getMask(); 329 330 unsigned Offset = MaybeIndex->Value.getZExtValue(); 331 int SrcIdx = Mask[Offset]; 332 333 LLT Src1Type = MRI.getType(Shuffle->getSrc1Reg()); 334 // At the IR level a <1 x ty> shuffle vector is valid, but we want to extract 335 // from a vector. 336 assert(Src1Type.isVector() && "expected to extract from a vector"); 337 unsigned LHSWidth = Src1Type.isVector() ? Src1Type.getNumElements() : 1; 338 339 // Note that there is no one use check. 340 Register Dst = Extract->getReg(0); 341 LLT DstTy = MRI.getType(Dst); 342 343 if (SrcIdx < 0 && 344 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 345 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 346 return true; 347 } 348 349 // If the legality check failed, then we still have to abort. 350 if (SrcIdx < 0) 351 return false; 352 353 Register NewVector; 354 355 // We check in which vector and at what offset to look through. 356 if (SrcIdx < (int)LHSWidth) { 357 NewVector = Shuffle->getSrc1Reg(); 358 // SrcIdx unchanged 359 } else { // SrcIdx >= LHSWidth 360 NewVector = Shuffle->getSrc2Reg(); 361 SrcIdx -= LHSWidth; 362 } 363 364 LLT IdxTy = MRI.getType(Extract->getIndexReg()); 365 LLT NewVectorTy = MRI.getType(NewVector); 366 367 // We check the legality of the look through. 368 if (!isLegalOrBeforeLegalizer( 369 {TargetOpcode::G_EXTRACT_VECTOR_ELT, {DstTy, NewVectorTy, IdxTy}}) || 370 !isConstantLegalOrBeforeLegalizer({IdxTy})) 371 return false; 372 373 // We look through the shuffle vector. 374 MatchInfo = [=](MachineIRBuilder &B) { 375 auto Idx = B.buildConstant(IdxTy, SrcIdx); 376 B.buildExtractVectorElement(Dst, NewVector, Idx); 377 }; 378 379 return true; 380 } 381 382 bool CombinerHelper::matchInsertVectorElementOOB(MachineInstr &MI, 383 BuildFnTy &MatchInfo) { 384 GInsertVectorElement *Insert = cast<GInsertVectorElement>(&MI); 385 386 Register Dst = Insert->getReg(0); 387 LLT DstTy = MRI.getType(Dst); 388 Register Index = Insert->getIndexReg(); 389 390 if (!DstTy.isFixedVector()) 391 return false; 392 393 std::optional<ValueAndVReg> MaybeIndex = 394 getIConstantVRegValWithLookThrough(Index, MRI); 395 396 if (MaybeIndex && MaybeIndex->Value.uge(DstTy.getNumElements()) && 397 isLegalOrBeforeLegalizer({TargetOpcode::G_IMPLICIT_DEF, {DstTy}})) { 398 MatchInfo = [=](MachineIRBuilder &B) { B.buildUndef(Dst); }; 399 return true; 400 } 401 402 return false; 403 } 404 405 bool CombinerHelper::matchAddOfVScale(const MachineOperand &MO, 406 BuildFnTy &MatchInfo) { 407 GAdd *Add = cast<GAdd>(MRI.getVRegDef(MO.getReg())); 408 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getLHSReg())); 409 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Add->getRHSReg())); 410 411 Register Dst = Add->getReg(0); 412 413 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) || 414 !MRI.hasOneNonDBGUse(RHSVScale->getReg(0))) 415 return false; 416 417 MatchInfo = [=](MachineIRBuilder &B) { 418 B.buildVScale(Dst, LHSVScale->getSrc() + RHSVScale->getSrc()); 419 }; 420 421 return true; 422 } 423 424 bool CombinerHelper::matchMulOfVScale(const MachineOperand &MO, 425 BuildFnTy &MatchInfo) { 426 GMul *Mul = cast<GMul>(MRI.getVRegDef(MO.getReg())); 427 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Mul->getLHSReg())); 428 429 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Mul->getRHSReg(), MRI); 430 if (!MaybeRHS) 431 return false; 432 433 Register Dst = MO.getReg(); 434 435 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0))) 436 return false; 437 438 MatchInfo = [=](MachineIRBuilder &B) { 439 B.buildVScale(Dst, LHSVScale->getSrc() * *MaybeRHS); 440 }; 441 442 return true; 443 } 444 445 bool CombinerHelper::matchSubOfVScale(const MachineOperand &MO, 446 BuildFnTy &MatchInfo) { 447 GSub *Sub = cast<GSub>(MRI.getVRegDef(MO.getReg())); 448 GVScale *RHSVScale = cast<GVScale>(MRI.getVRegDef(Sub->getRHSReg())); 449 450 Register Dst = MO.getReg(); 451 LLT DstTy = MRI.getType(Dst); 452 453 if (!MRI.hasOneNonDBGUse(RHSVScale->getReg(0)) || 454 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, DstTy})) 455 return false; 456 457 MatchInfo = [=](MachineIRBuilder &B) { 458 auto VScale = B.buildVScale(DstTy, -RHSVScale->getSrc()); 459 B.buildAdd(Dst, Sub->getLHSReg(), VScale, Sub->getFlags()); 460 }; 461 462 return true; 463 } 464 465 bool CombinerHelper::matchShlOfVScale(const MachineOperand &MO, 466 BuildFnTy &MatchInfo) { 467 GShl *Shl = cast<GShl>(MRI.getVRegDef(MO.getReg())); 468 GVScale *LHSVScale = cast<GVScale>(MRI.getVRegDef(Shl->getSrcReg())); 469 470 std::optional<APInt> MaybeRHS = getIConstantVRegVal(Shl->getShiftReg(), MRI); 471 if (!MaybeRHS) 472 return false; 473 474 Register Dst = MO.getReg(); 475 LLT DstTy = MRI.getType(Dst); 476 477 if (!MRI.hasOneNonDBGUse(LHSVScale->getReg(0)) || 478 !isLegalOrBeforeLegalizer({TargetOpcode::G_VSCALE, DstTy})) 479 return false; 480 481 MatchInfo = [=](MachineIRBuilder &B) { 482 B.buildVScale(Dst, LHSVScale->getSrc().shl(*MaybeRHS)); 483 }; 484 485 return true; 486 } 487