1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the targeting of the Machinelegalizer class for SPIR-V. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVLegalizerInfo.h" 14 #include "SPIRV.h" 15 #include "SPIRVGlobalRegistry.h" 16 #include "SPIRVSubtarget.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/MachineInstr.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 23 using namespace llvm; 24 using namespace llvm::LegalizeActions; 25 using namespace llvm::LegalityPredicates; 26 27 static const std::set<unsigned> TypeFoldingSupportingOpcs = { 28 TargetOpcode::G_ADD, 29 TargetOpcode::G_FADD, 30 TargetOpcode::G_SUB, 31 TargetOpcode::G_FSUB, 32 TargetOpcode::G_MUL, 33 TargetOpcode::G_FMUL, 34 TargetOpcode::G_SDIV, 35 TargetOpcode::G_UDIV, 36 TargetOpcode::G_FDIV, 37 TargetOpcode::G_SREM, 38 TargetOpcode::G_UREM, 39 TargetOpcode::G_FREM, 40 TargetOpcode::G_FNEG, 41 TargetOpcode::G_CONSTANT, 42 TargetOpcode::G_FCONSTANT, 43 TargetOpcode::G_AND, 44 TargetOpcode::G_OR, 45 TargetOpcode::G_XOR, 46 TargetOpcode::G_SHL, 47 TargetOpcode::G_ASHR, 48 TargetOpcode::G_LSHR, 49 TargetOpcode::G_SELECT, 50 TargetOpcode::G_EXTRACT_VECTOR_ELT, 51 }; 52 53 bool isTypeFoldingSupported(unsigned Opcode) { 54 return TypeFoldingSupportingOpcs.count(Opcode) > 0; 55 } 56 57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { 58 using namespace TargetOpcode; 59 60 this->ST = &ST; 61 GR = ST.getSPIRVGlobalRegistry(); 62 63 const LLT s1 = LLT::scalar(1); 64 const LLT s8 = LLT::scalar(8); 65 const LLT s16 = LLT::scalar(16); 66 const LLT s32 = LLT::scalar(32); 67 const LLT s64 = LLT::scalar(64); 68 69 const LLT v16s64 = LLT::fixed_vector(16, 64); 70 const LLT v16s32 = LLT::fixed_vector(16, 32); 71 const LLT v16s16 = LLT::fixed_vector(16, 16); 72 const LLT v16s8 = LLT::fixed_vector(16, 8); 73 const LLT v16s1 = LLT::fixed_vector(16, 1); 74 75 const LLT v8s64 = LLT::fixed_vector(8, 64); 76 const LLT v8s32 = LLT::fixed_vector(8, 32); 77 const LLT v8s16 = LLT::fixed_vector(8, 16); 78 const LLT v8s8 = LLT::fixed_vector(8, 8); 79 const LLT v8s1 = LLT::fixed_vector(8, 1); 80 81 const LLT v4s64 = LLT::fixed_vector(4, 64); 82 const LLT v4s32 = LLT::fixed_vector(4, 32); 83 const LLT v4s16 = LLT::fixed_vector(4, 16); 84 const LLT v4s8 = LLT::fixed_vector(4, 8); 85 const LLT v4s1 = LLT::fixed_vector(4, 1); 86 87 const LLT v3s64 = LLT::fixed_vector(3, 64); 88 const LLT v3s32 = LLT::fixed_vector(3, 32); 89 const LLT v3s16 = LLT::fixed_vector(3, 16); 90 const LLT v3s8 = LLT::fixed_vector(3, 8); 91 const LLT v3s1 = LLT::fixed_vector(3, 1); 92 93 const LLT v2s64 = LLT::fixed_vector(2, 64); 94 const LLT v2s32 = LLT::fixed_vector(2, 32); 95 const LLT v2s16 = LLT::fixed_vector(2, 16); 96 const LLT v2s8 = LLT::fixed_vector(2, 8); 97 const LLT v2s1 = LLT::fixed_vector(2, 1); 98 99 const unsigned PSize = ST.getPointerSize(); 100 const LLT p0 = LLT::pointer(0, PSize); // Function 101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup 102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant 103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup 104 const LLT p4 = LLT::pointer(4, PSize); // Generic 105 const LLT p5 = LLT::pointer(5, PSize); // Input 106 107 // TODO: remove copy-pasting here by using concatenation in some way. 108 auto allPtrsScalarsAndVectors = { 109 p0, p1, p2, p3, p4, p5, s1, s8, s16, 110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, 111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, 112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 113 114 auto allScalarsAndVectors = { 115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 118 119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, 120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, 121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, 122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; 123 124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; 125 126 auto allIntScalars = {s8, s16, s32, s64}; 127 128 auto allFloatScalarsAndVectors = { 129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, 130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; 131 132 auto allFloatAndIntScalars = allIntScalars; 133 134 auto allPtrs = {p0, p1, p2, p3, p4, p5}; 135 auto allWritablePtrs = {p0, p1, p3, p4}; 136 137 for (auto Opc : TypeFoldingSupportingOpcs) 138 getActionDefinitionsBuilder(Opc).custom(); 139 140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); 141 142 // TODO: add proper rules for vectors legalization. 143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); 144 145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) 146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); 147 148 getActionDefinitionsBuilder(G_MEMSET).legalIf( 149 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars))); 150 151 getActionDefinitionsBuilder(G_ADDRSPACE_CAST) 152 .legalForCartesianProduct(allPtrs, allPtrs); 153 154 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); 155 156 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors); 157 158 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); 159 160 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) 161 .legalForCartesianProduct(allIntScalarsAndVectors, 162 allFloatScalarsAndVectors); 163 164 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) 165 .legalForCartesianProduct(allFloatScalarsAndVectors, 166 allScalarsAndVectors); 167 168 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) 169 .legalFor(allIntScalarsAndVectors); 170 171 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct( 172 allIntScalarsAndVectors, allIntScalarsAndVectors); 173 174 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors); 175 176 getActionDefinitionsBuilder(G_BITCAST).legalIf(all( 177 typeInSet(0, allPtrsScalarsAndVectors), 178 typeInSet(1, allPtrsScalarsAndVectors), 179 LegalityPredicate(([=](const LegalityQuery &Query) { 180 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits(); 181 })))); 182 183 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal(); 184 185 getActionDefinitionsBuilder(G_INTTOPTR) 186 .legalForCartesianProduct(allPtrs, allIntScalars); 187 getActionDefinitionsBuilder(G_PTRTOINT) 188 .legalForCartesianProduct(allIntScalars, allPtrs); 189 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct( 190 allPtrs, allIntScalars); 191 192 // ST.canDirectlyComparePointers() for pointer args is supported in 193 // legalizeCustom(). 194 getActionDefinitionsBuilder(G_ICMP).customIf( 195 all(typeInSet(0, allBoolScalarsAndVectors), 196 typeInSet(1, allPtrsScalarsAndVectors))); 197 198 getActionDefinitionsBuilder(G_FCMP).legalIf( 199 all(typeInSet(0, allBoolScalarsAndVectors), 200 typeInSet(1, allFloatScalarsAndVectors))); 201 202 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, 203 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, 204 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, 205 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) 206 .legalForCartesianProduct(allIntScalars, allWritablePtrs); 207 208 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) 209 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs); 210 211 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); 212 // TODO: add proper legalization rules. 213 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); 214 215 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO}) 216 .alwaysLegal(); 217 218 // Extensions. 219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) 220 .legalForCartesianProduct(allScalarsAndVectors); 221 222 // FP conversions. 223 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) 224 .legalForCartesianProduct(allFloatScalarsAndVectors); 225 226 // Pointer-handling. 227 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); 228 229 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 230 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); 231 232 getActionDefinitionsBuilder({G_FPOW, 233 G_FEXP, 234 G_FEXP2, 235 G_FLOG, 236 G_FLOG2, 237 G_FABS, 238 G_FMINNUM, 239 G_FMAXNUM, 240 G_FCEIL, 241 G_FCOS, 242 G_FSIN, 243 G_FSQRT, 244 G_FFLOOR, 245 G_FRINT, 246 G_FNEARBYINT, 247 G_INTRINSIC_ROUND, 248 G_INTRINSIC_TRUNC, 249 G_FMINIMUM, 250 G_FMAXIMUM, 251 G_INTRINSIC_ROUNDEVEN}) 252 .legalFor(allFloatScalarsAndVectors); 253 254 getActionDefinitionsBuilder(G_FCOPYSIGN) 255 .legalForCartesianProduct(allFloatScalarsAndVectors, 256 allFloatScalarsAndVectors); 257 258 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( 259 allFloatScalarsAndVectors, allIntScalarsAndVectors); 260 261 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 262 getActionDefinitionsBuilder(G_FLOG10).legalFor(allFloatScalarsAndVectors); 263 264 getActionDefinitionsBuilder( 265 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) 266 .legalForCartesianProduct(allIntScalarsAndVectors, 267 allIntScalarsAndVectors); 268 269 // Struct return types become a single scalar, so cannot easily legalize. 270 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal(); 271 } 272 273 getLegacyLegalizerInfo().computeTables(); 274 verify(*ST.getInstrInfo()); 275 } 276 277 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, 278 LegalizerHelper &Helper, 279 MachineRegisterInfo &MRI, 280 SPIRVGlobalRegistry *GR) { 281 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); 282 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF()); 283 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) 284 .addDef(ConvReg) 285 .addUse(Reg); 286 return ConvReg; 287 } 288 289 bool SPIRVLegalizerInfo::legalizeCustom(LegalizerHelper &Helper, 290 MachineInstr &MI) const { 291 auto Opc = MI.getOpcode(); 292 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 293 if (!isTypeFoldingSupported(Opc)) { 294 assert(Opc == TargetOpcode::G_ICMP); 295 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 296 auto &Op0 = MI.getOperand(2); 297 auto &Op1 = MI.getOperand(3); 298 Register Reg0 = Op0.getReg(); 299 Register Reg1 = Op1.getReg(); 300 CmpInst::Predicate Cond = 301 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 302 if ((!ST->canDirectlyComparePointers() || 303 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && 304 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { 305 LLT ConvT = LLT::scalar(ST->getPointerSize()); 306 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), 307 ST->getPointerSize()); 308 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); 309 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); 310 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); 311 } 312 return true; 313 } 314 // TODO: implement legalization for other opcodes. 315 return true; 316 } 317