1 //===- CombinerHelperCasts.cpp---------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and 10 // G_ZEXT 11 // 12 //===----------------------------------------------------------------------===// 13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 17 #include "llvm/CodeGen/GlobalISel/Utils.h" 18 #include "llvm/CodeGen/LowLevelTypeUtils.h" 19 #include "llvm/CodeGen/MachineOperand.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 #include "llvm/Support/Casting.h" 23 24 #define DEBUG_TYPE "gi-combiner" 25 26 using namespace llvm; 27 28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO, 29 BuildFnTy &MatchInfo) const { 30 GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI)); 31 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI)); 32 33 Register Dst = Sext->getReg(0); 34 Register Src = Trunc->getSrcReg(); 35 36 LLT DstTy = MRI.getType(Dst); 37 LLT SrcTy = MRI.getType(Src); 38 39 // Combines without nsw trunc. 40 if (!Trunc->getFlag(MachineInstr::NoSWrap)) { 41 if (DstTy != SrcTy || 42 !isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT_INREG, {DstTy, SrcTy}})) 43 return false; 44 45 // Do this for 8 bit values and up. We don't want to do it for e.g. G_TRUNC 46 // to i1. 47 unsigned TruncWidth = MRI.getType(Trunc->getReg(0)).getScalarSizeInBits(); 48 if (TruncWidth < 8) 49 return false; 50 51 MatchInfo = [=](MachineIRBuilder &B) { 52 B.buildSExtInReg(Dst, Src, TruncWidth); 53 }; 54 return true; 55 } 56 57 // Combines for nsw trunc. 58 59 if (DstTy == SrcTy) { 60 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 61 return true; 62 } 63 64 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 65 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 66 MatchInfo = [=](MachineIRBuilder &B) { 67 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap); 68 }; 69 return true; 70 } 71 72 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 73 isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) { 74 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 75 return true; 76 } 77 78 return false; 79 } 80 81 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO, 82 BuildFnTy &MatchInfo) const { 83 GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI)); 84 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI)); 85 86 Register Dst = Zext->getReg(0); 87 Register Src = Trunc->getSrcReg(); 88 89 LLT DstTy = MRI.getType(Dst); 90 LLT SrcTy = MRI.getType(Src); 91 92 if (DstTy == SrcTy) { 93 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 94 return true; 95 } 96 97 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() && 98 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) { 99 MatchInfo = [=](MachineIRBuilder &B) { 100 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap); 101 }; 102 return true; 103 } 104 105 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() && 106 isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) { 107 MatchInfo = [=](MachineIRBuilder &B) { 108 B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg); 109 }; 110 return true; 111 } 112 113 return false; 114 } 115 116 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO, 117 BuildFnTy &MatchInfo) const { 118 GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg())); 119 120 Register Dst = Zext->getReg(0); 121 Register Src = Zext->getSrcReg(); 122 123 LLT DstTy = MRI.getType(Dst); 124 LLT SrcTy = MRI.getType(Src); 125 const auto &TLI = getTargetLowering(); 126 127 // Convert zext nneg to sext if sext is the preferred form for the target. 128 if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) && 129 TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) { 130 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 131 return true; 132 } 133 134 return false; 135 } 136 137 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root, 138 const MachineInstr &ExtMI, 139 BuildFnTy &MatchInfo) const { 140 const GTrunc *Trunc = cast<GTrunc>(&Root); 141 const GExtOp *Ext = cast<GExtOp>(&ExtMI); 142 143 if (!MRI.hasOneNonDBGUse(Ext->getReg(0))) 144 return false; 145 146 Register Dst = Trunc->getReg(0); 147 Register Src = Ext->getSrcReg(); 148 LLT DstTy = MRI.getType(Dst); 149 LLT SrcTy = MRI.getType(Src); 150 151 if (SrcTy == DstTy) { 152 // The source and the destination are equally sized. We need to copy. 153 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); }; 154 155 return true; 156 } 157 158 if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) { 159 // If the source is smaller than the destination, we need to extend. 160 161 if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}})) 162 return false; 163 164 MatchInfo = [=](MachineIRBuilder &B) { 165 B.buildInstr(Ext->getOpcode(), {Dst}, {Src}); 166 }; 167 168 return true; 169 } 170 171 if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) { 172 // If the source is larger than the destination, then we need to truncate. 173 174 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) 175 return false; 176 177 MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); }; 178 179 return true; 180 } 181 182 return false; 183 } 184 185 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const { 186 const TargetLowering &TLI = getTargetLowering(); 187 LLVMContext &Ctx = getContext(); 188 189 switch (Opcode) { 190 case TargetOpcode::G_ANYEXT: 191 case TargetOpcode::G_ZEXT: 192 return TLI.isZExtFree(FromTy, ToTy, Ctx); 193 case TargetOpcode::G_TRUNC: 194 return TLI.isTruncateFree(FromTy, ToTy, Ctx); 195 default: 196 return false; 197 } 198 } 199 200 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI, 201 const MachineInstr &SelectMI, 202 BuildFnTy &MatchInfo) const { 203 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 204 const GSelect *Select = cast<GSelect>(&SelectMI); 205 206 if (!MRI.hasOneNonDBGUse(Select->getReg(0))) 207 return false; 208 209 Register Dst = Cast->getReg(0); 210 LLT DstTy = MRI.getType(Dst); 211 LLT CondTy = MRI.getType(Select->getCondReg()); 212 Register TrueReg = Select->getTrueReg(); 213 Register FalseReg = Select->getFalseReg(); 214 LLT SrcTy = MRI.getType(TrueReg); 215 Register Cond = Select->getCondReg(); 216 217 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}})) 218 return false; 219 220 if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy)) 221 return false; 222 223 MatchInfo = [=](MachineIRBuilder &B) { 224 auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg}); 225 auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg}); 226 B.buildSelect(Dst, Cond, True, False); 227 }; 228 229 return true; 230 } 231 232 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI, 233 const MachineInstr &SecondMI, 234 BuildFnTy &MatchInfo) const { 235 const GExtOp *First = cast<GExtOp>(&FirstMI); 236 const GExtOp *Second = cast<GExtOp>(&SecondMI); 237 238 Register Dst = First->getReg(0); 239 Register Src = Second->getSrcReg(); 240 LLT DstTy = MRI.getType(Dst); 241 LLT SrcTy = MRI.getType(Src); 242 243 if (!MRI.hasOneNonDBGUse(Second->getReg(0))) 244 return false; 245 246 // ext of ext -> later ext 247 if (First->getOpcode() == Second->getOpcode() && 248 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 249 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 250 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 251 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 252 Flag = MachineInstr::MIFlag::NonNeg; 253 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 254 return true; 255 } 256 // not zext -> no flags 257 MatchInfo = [=](MachineIRBuilder &B) { 258 B.buildInstr(Second->getOpcode(), {Dst}, {Src}); 259 }; 260 return true; 261 } 262 263 // anyext of sext/zext -> sext/zext 264 // -> pick anyext as second ext, then ext of ext 265 if (First->getOpcode() == TargetOpcode::G_ANYEXT && 266 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) { 267 if (Second->getOpcode() == TargetOpcode::G_ZEXT) { 268 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 269 if (Second->getFlag(MachineInstr::MIFlag::NonNeg)) 270 Flag = MachineInstr::MIFlag::NonNeg; 271 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 272 return true; 273 } 274 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 275 return true; 276 } 277 278 // sext/zext of anyext -> sext/zext 279 // -> pick anyext as first ext, then ext of ext 280 if (Second->getOpcode() == TargetOpcode::G_ANYEXT && 281 isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) { 282 if (First->getOpcode() == TargetOpcode::G_ZEXT) { 283 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags; 284 if (First->getFlag(MachineInstr::MIFlag::NonNeg)) 285 Flag = MachineInstr::MIFlag::NonNeg; 286 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); }; 287 return true; 288 } 289 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); }; 290 return true; 291 } 292 293 return false; 294 } 295 296 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI, 297 const MachineInstr &BVMI, 298 BuildFnTy &MatchInfo) const { 299 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 300 const GBuildVector *BV = cast<GBuildVector>(&BVMI); 301 302 if (!MRI.hasOneNonDBGUse(BV->getReg(0))) 303 return false; 304 305 Register Dst = Cast->getReg(0); 306 // The type of the new build vector. 307 LLT DstTy = MRI.getType(Dst); 308 // The scalar or element type of the new build vector. 309 LLT ElemTy = DstTy.getScalarType(); 310 // The scalar or element type of the old build vector. 311 LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType(); 312 313 // Check legality of new build vector, the scalar casts, and profitability of 314 // the many casts. 315 if (!isLegalOrBeforeLegalizer( 316 {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) || 317 !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) || 318 !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy)) 319 return false; 320 321 MatchInfo = [=](MachineIRBuilder &B) { 322 SmallVector<Register> Casts; 323 unsigned Elements = BV->getNumSources(); 324 for (unsigned I = 0; I < Elements; ++I) { 325 auto CastI = 326 B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)}); 327 Casts.push_back(CastI.getReg(0)); 328 } 329 330 B.buildBuildVector(Dst, Casts); 331 }; 332 333 return true; 334 } 335 336 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI, 337 const MachineInstr &BinopMI, 338 BuildFnTy &MatchInfo) const { 339 const GTrunc *Trunc = cast<GTrunc>(&TruncMI); 340 const GBinOp *BinOp = cast<GBinOp>(&BinopMI); 341 342 if (!MRI.hasOneNonDBGUse(BinOp->getReg(0))) 343 return false; 344 345 Register Dst = Trunc->getReg(0); 346 LLT DstTy = MRI.getType(Dst); 347 348 // Is narrow binop legal? 349 if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}})) 350 return false; 351 352 MatchInfo = [=](MachineIRBuilder &B) { 353 auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg()); 354 auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg()); 355 B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS}); 356 }; 357 358 return true; 359 } 360 361 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI, 362 APInt &MatchInfo) const { 363 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI); 364 365 APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI); 366 367 LLT DstTy = MRI.getType(Cast->getReg(0)); 368 369 if (!isConstantLegalOrBeforeLegalizer(DstTy)) 370 return false; 371 372 switch (Cast->getOpcode()) { 373 case TargetOpcode::G_TRUNC: { 374 MatchInfo = Input.trunc(DstTy.getScalarSizeInBits()); 375 return true; 376 } 377 default: 378 return false; 379 } 380 } 381 382 bool CombinerHelper::matchRedundantSextInReg(MachineInstr &Root, 383 MachineInstr &Other, 384 BuildFnTy &MatchInfo) const { 385 assert(Root.getOpcode() == TargetOpcode::G_SEXT_INREG && 386 Other.getOpcode() == TargetOpcode::G_SEXT_INREG); 387 388 unsigned RootWidth = Root.getOperand(2).getImm(); 389 unsigned OtherWidth = Other.getOperand(2).getImm(); 390 391 Register Dst = Root.getOperand(0).getReg(); 392 Register OtherDst = Other.getOperand(0).getReg(); 393 Register Src = Other.getOperand(1).getReg(); 394 395 if (RootWidth >= OtherWidth) { 396 // The root sext_inreg is entirely redundant because the other one 397 // is narrower. 398 if (!canReplaceReg(Dst, OtherDst, MRI)) 399 return false; 400 401 MatchInfo = [=](MachineIRBuilder &B) { 402 Observer.changingAllUsesOfReg(MRI, Dst); 403 MRI.replaceRegWith(Dst, OtherDst); 404 Observer.finishedChangingAllUsesOfReg(); 405 }; 406 } else { 407 // RootWidth < OtherWidth, rewrite this G_SEXT_INREG with the source of the 408 // other G_SEXT_INREG. 409 MatchInfo = [=](MachineIRBuilder &B) { 410 B.buildSExtInReg(Dst, Src, RootWidth); 411 }; 412 } 413 414 return true; 415 } 416