1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the targeting of the Machinelegalizer class for SPIR-V. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVLegalizerInfo.h" 14 #include "SPIRV.h" 15 #include "SPIRVGlobalRegistry.h" 16 #include "SPIRVSubtarget.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/MachineInstr.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 23 using namespace llvm; 24 using namespace llvm::LegalizeActions; 25 using namespace llvm::LegalityPredicates; 26 27 static const std::set<unsigned> TypeFoldingSupportingOpcs = { 28 TargetOpcode::G_ADD, 29 TargetOpcode::G_FADD, 30 TargetOpcode::G_SUB, 31 TargetOpcode::G_FSUB, 32 TargetOpcode::G_MUL, 33 TargetOpcode::G_FMUL, 34 TargetOpcode::G_SDIV, 35 TargetOpcode::G_UDIV, 36 TargetOpcode::G_FDIV, 37 TargetOpcode::G_SREM, 38 TargetOpcode::G_UREM, 39 TargetOpcode::G_FREM, 40 TargetOpcode::G_FNEG, 41 TargetOpcode::G_CONSTANT, 42 TargetOpcode::G_FCONSTANT, 43 TargetOpcode::G_AND, 44 TargetOpcode::G_OR, 45 TargetOpcode::G_XOR, 46 TargetOpcode::G_SHL, 47 TargetOpcode::G_ASHR, 48 TargetOpcode::G_LSHR, 49 TargetOpcode::G_SELECT, 50 TargetOpcode::G_EXTRACT_VECTOR_ELT, 51 }; 52 53 bool isTypeFoldingSupported(unsigned Opcode) { 54 return TypeFoldingSupportingOpcs.count(Opcode) > 0; 55 } 56 57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { 58 using namespace TargetOpcode; 59 60 this->ST = &ST; 61 GR = ST.getSPIRVGlobalRegistry(); 62 63 const LLT s1 = LLT::scalar(1); 64 const LLT s8 = LLT::scalar(8); 65 const LLT s16 = LLT::scalar(16); 66 const LLT s32 = LLT::scalar(32); 67 const LLT s64 = LLT::scalar(64); 68 69 const LLT v16s64 = LLT::fixed_vector(16, 64); 70 const LLT v16s32 = LLT::fixed_vector(16, 32); 71 const LLT v16s16 = LLT::fixed_vector(16, 16); 72 const LLT v16s8 = LLT::fixed_vector(16, 8); 73 const LLT v16s1 = LLT::fixed_vector(16, 1); 74 75 const LLT v8s64 = LLT::fixed_vector(8, 64); 76 const LLT v8s32 = LLT::fixed_vector(8, 32); 77 const LLT v8s16 = LLT::fixed_vector(8, 16); 78 const LLT v8s8 = LLT::fixed_vector(8, 8); 79 const LLT v8s1 = LLT::fixed_vector(8, 1); 80 81 const LLT v4s64 = LLT::fixed_vector(4, 64); 82 const LLT v4s32 = LLT::fixed_vector(4, 32); 83 const LLT v4s16 = LLT::fixed_vector(4, 16); 84 const LLT v4s8 = LLT::fixed_vector(4, 8); 85 const LLT v4s1 = LLT::fixed_vector(4, 1); 86 87 const LLT v3s64 = LLT::fixed_vector(3, 64); 88 const LLT v3s32 = LLT::fixed_vector(3, 32); 89 const LLT v3s16 = LLT::fixed_vector(3, 16); 90 const LLT v3s8 = LLT::fixed_vector(3, 8); 91 const LLT v3s1 = LLT::fixed_vector(3, 1); 92 93 const LLT v2s64 = LLT::fixed_vector(2, 64); 94 const LLT v2s32 = LLT::fixed_vector(2, 32); 95 const LLT v2s16 = LLT::fixed_vector(2, 16); 96 const LLT v2s8 = LLT::fixed_vector(2, 8); 97 const LLT v2s1 = LLT::fixed_vector(2, 1); 98 99 const unsigned PSize = ST.getPointerSize(); 100 const LLT p0 = LLT::pointer(0, PSize); // Function 101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup 102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant 103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup 104 const LLT p4 = LLT::pointer(4, PSize); // Generic 105 const LLT p5 = LLT::pointer(5, PSize); // Input 106 107 // TODO: remove copy-pasting here by using concatenation in some way. 108 auto allPtrsScalarsAndVectors = { 109 p0, p1, p2, p3, p4, p5, s1, s8, s16, 110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, 111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, 112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 113 114 auto allScalarsAndVectors = { 115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 118 119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, 120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, 121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, 122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; 123 124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; 125 126 auto allIntScalars = {s8, s16, s32, s64}; 127 128 auto allFloatScalarsAndVectors = { 129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, 130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; 131 132 auto allFloatAndIntScalars = allIntScalars; 133 134 auto allPtrs = {p0, p1, p2, p3, p4, p5}; 135 auto allWritablePtrs = {p0, p1, p3, p4}; 136 137 for (auto Opc : TypeFoldingSupportingOpcs) 138 getActionDefinitionsBuilder(Opc).custom(); 139 140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); 141 142 // TODO: add proper rules for vectors legalization. 143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); 144 145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) 146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); 147 148 getActionDefinitionsBuilder(G_ADDRSPACE_CAST) 149 .legalForCartesianProduct(allPtrs, allPtrs); 150 151 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); 152 153 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors); 154 155 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); 156 157 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) 158 .legalForCartesianProduct(allIntScalarsAndVectors, 159 allFloatScalarsAndVectors); 160 161 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) 162 .legalForCartesianProduct(allFloatScalarsAndVectors, 163 allScalarsAndVectors); 164 165 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) 166 .legalFor(allIntScalarsAndVectors); 167 168 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct( 169 allIntScalarsAndVectors, allIntScalarsAndVectors); 170 171 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors); 172 173 getActionDefinitionsBuilder(G_BITCAST).legalIf(all( 174 typeInSet(0, allPtrsScalarsAndVectors), 175 typeInSet(1, allPtrsScalarsAndVectors), 176 LegalityPredicate(([=](const LegalityQuery &Query) { 177 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits(); 178 })))); 179 180 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal(); 181 182 getActionDefinitionsBuilder(G_INTTOPTR) 183 .legalForCartesianProduct(allPtrs, allIntScalars); 184 getActionDefinitionsBuilder(G_PTRTOINT) 185 .legalForCartesianProduct(allIntScalars, allPtrs); 186 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct( 187 allPtrs, allIntScalars); 188 189 // ST.canDirectlyComparePointers() for pointer args is supported in 190 // legalizeCustom(). 191 getActionDefinitionsBuilder(G_ICMP).customIf( 192 all(typeInSet(0, allBoolScalarsAndVectors), 193 typeInSet(1, allPtrsScalarsAndVectors))); 194 195 getActionDefinitionsBuilder(G_FCMP).legalIf( 196 all(typeInSet(0, allBoolScalarsAndVectors), 197 typeInSet(1, allFloatScalarsAndVectors))); 198 199 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, 200 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, 201 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, 202 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) 203 .legalForCartesianProduct(allIntScalars, allWritablePtrs); 204 205 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) 206 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs); 207 208 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); 209 // TODO: add proper legalization rules. 210 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); 211 212 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO}) 213 .alwaysLegal(); 214 215 // Extensions. 216 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) 217 .legalForCartesianProduct(allScalarsAndVectors); 218 219 // FP conversions. 220 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) 221 .legalForCartesianProduct(allFloatScalarsAndVectors); 222 223 // Pointer-handling. 224 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); 225 226 // Control-flow. 227 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1}); 228 229 getActionDefinitionsBuilder({G_FPOW, 230 G_FEXP, 231 G_FEXP2, 232 G_FLOG, 233 G_FLOG2, 234 G_FABS, 235 G_FMINNUM, 236 G_FMAXNUM, 237 G_FCEIL, 238 G_FCOS, 239 G_FSIN, 240 G_FSQRT, 241 G_FFLOOR, 242 G_FRINT, 243 G_FNEARBYINT, 244 G_INTRINSIC_ROUND, 245 G_INTRINSIC_TRUNC, 246 G_FMINIMUM, 247 G_FMAXIMUM, 248 G_INTRINSIC_ROUNDEVEN}) 249 .legalFor(allFloatScalarsAndVectors); 250 251 getActionDefinitionsBuilder(G_FCOPYSIGN) 252 .legalForCartesianProduct(allFloatScalarsAndVectors, 253 allFloatScalarsAndVectors); 254 255 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( 256 allFloatScalarsAndVectors, allIntScalarsAndVectors); 257 258 getLegacyLegalizerInfo().computeTables(); 259 verify(*ST.getInstrInfo()); 260 } 261 262 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, 263 LegalizerHelper &Helper, 264 MachineRegisterInfo &MRI, 265 SPIRVGlobalRegistry *GR) { 266 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); 267 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF()); 268 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) 269 .addDef(ConvReg) 270 .addUse(Reg); 271 return ConvReg; 272 } 273 274 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper, 275 MachineInstr &MI) const { 276 auto Opc = MI.getOpcode(); 277 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 278 if (!isTypeFoldingSupported(Opc)) { 279 assert(Opc == TargetOpcode::G_ICMP); 280 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 281 auto &Op0 = MI.getOperand(2); 282 auto &Op1 = MI.getOperand(3); 283 Register Reg0 = Op0.getReg(); 284 Register Reg1 = Op1.getReg(); 285 CmpInst::Predicate Cond = 286 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 287 if ((!ST->canDirectlyComparePointers() || 288 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && 289 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { 290 LLT ConvT = LLT::scalar(ST->getPointerSize()); 291 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), 292 ST->getPointerSize()); 293 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); 294 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); 295 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); 296 } 297 return true; 298 } 299 // TODO: implement legalization for other opcodes. 300 return true; 301 } 302