1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the targeting of the Machinelegalizer class for SPIR-V. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVLegalizerInfo.h" 14 #include "SPIRV.h" 15 #include "SPIRVGlobalRegistry.h" 16 #include "SPIRVSubtarget.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/MachineInstr.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 23 using namespace llvm; 24 using namespace llvm::LegalizeActions; 25 using namespace llvm::LegalityPredicates; 26 27 static const std::set<unsigned> TypeFoldingSupportingOpcs = { 28 TargetOpcode::G_ADD, 29 TargetOpcode::G_FADD, 30 TargetOpcode::G_SUB, 31 TargetOpcode::G_FSUB, 32 TargetOpcode::G_MUL, 33 TargetOpcode::G_FMUL, 34 TargetOpcode::G_SDIV, 35 TargetOpcode::G_UDIV, 36 TargetOpcode::G_FDIV, 37 TargetOpcode::G_SREM, 38 TargetOpcode::G_UREM, 39 TargetOpcode::G_FREM, 40 TargetOpcode::G_FNEG, 41 TargetOpcode::G_CONSTANT, 42 TargetOpcode::G_FCONSTANT, 43 TargetOpcode::G_AND, 44 TargetOpcode::G_OR, 45 TargetOpcode::G_XOR, 46 TargetOpcode::G_SHL, 47 TargetOpcode::G_ASHR, 48 TargetOpcode::G_LSHR, 49 TargetOpcode::G_SELECT, 50 TargetOpcode::G_EXTRACT_VECTOR_ELT, 51 }; 52 53 bool isTypeFoldingSupported(unsigned Opcode) { 54 return TypeFoldingSupportingOpcs.count(Opcode) > 0; 55 } 56 57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { 58 using namespace TargetOpcode; 59 60 this->ST = &ST; 61 GR = ST.getSPIRVGlobalRegistry(); 62 63 const LLT s1 = LLT::scalar(1); 64 const LLT s8 = LLT::scalar(8); 65 const LLT s16 = LLT::scalar(16); 66 const LLT s32 = LLT::scalar(32); 67 const LLT s64 = LLT::scalar(64); 68 69 const LLT v16s64 = LLT::fixed_vector(16, 64); 70 const LLT v16s32 = LLT::fixed_vector(16, 32); 71 const LLT v16s16 = LLT::fixed_vector(16, 16); 72 const LLT v16s8 = LLT::fixed_vector(16, 8); 73 const LLT v16s1 = LLT::fixed_vector(16, 1); 74 75 const LLT v8s64 = LLT::fixed_vector(8, 64); 76 const LLT v8s32 = LLT::fixed_vector(8, 32); 77 const LLT v8s16 = LLT::fixed_vector(8, 16); 78 const LLT v8s8 = LLT::fixed_vector(8, 8); 79 const LLT v8s1 = LLT::fixed_vector(8, 1); 80 81 const LLT v4s64 = LLT::fixed_vector(4, 64); 82 const LLT v4s32 = LLT::fixed_vector(4, 32); 83 const LLT v4s16 = LLT::fixed_vector(4, 16); 84 const LLT v4s8 = LLT::fixed_vector(4, 8); 85 const LLT v4s1 = LLT::fixed_vector(4, 1); 86 87 const LLT v3s64 = LLT::fixed_vector(3, 64); 88 const LLT v3s32 = LLT::fixed_vector(3, 32); 89 const LLT v3s16 = LLT::fixed_vector(3, 16); 90 const LLT v3s8 = LLT::fixed_vector(3, 8); 91 const LLT v3s1 = LLT::fixed_vector(3, 1); 92 93 const LLT v2s64 = LLT::fixed_vector(2, 64); 94 const LLT v2s32 = LLT::fixed_vector(2, 32); 95 const LLT v2s16 = LLT::fixed_vector(2, 16); 96 const LLT v2s8 = LLT::fixed_vector(2, 8); 97 const LLT v2s1 = LLT::fixed_vector(2, 1); 98 99 const unsigned PSize = ST.getPointerSize(); 100 const LLT p0 = LLT::pointer(0, PSize); // Function 101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup 102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant 103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup 104 const LLT p4 = LLT::pointer(4, PSize); // Generic 105 const LLT p5 = LLT::pointer(5, PSize); // Input 106 107 // TODO: remove copy-pasting here by using concatenation in some way. 108 auto allPtrsScalarsAndVectors = { 109 p0, p1, p2, p3, p4, p5, s1, s8, s16, 110 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, 111 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, 112 v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 113 114 auto allScalarsAndVectors = { 115 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 116 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 117 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 118 119 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, 120 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, 121 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, 122 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; 123 124 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; 125 126 auto allIntScalars = {s8, s16, s32, s64}; 127 128 auto allFloatScalarsAndVectors = { 129 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, 130 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; 131 132 auto allFloatAndIntScalars = allIntScalars; 133 134 auto allPtrs = {p0, p1, p2, p3, p4, p5}; 135 auto allWritablePtrs = {p0, p1, p3, p4}; 136 137 for (auto Opc : TypeFoldingSupportingOpcs) 138 getActionDefinitionsBuilder(Opc).custom(); 139 140 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); 141 142 // TODO: add proper rules for vectors legalization. 143 getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal(); 144 145 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) 146 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); 147 148 getActionDefinitionsBuilder(G_MEMSET).legalIf( 149 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars))); 150 151 getActionDefinitionsBuilder(G_ADDRSPACE_CAST) 152 .legalForCartesianProduct(allPtrs, allPtrs); 153 154 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); 155 156 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allFloatScalarsAndVectors); 157 158 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); 159 160 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) 161 .legalForCartesianProduct(allIntScalarsAndVectors, 162 allFloatScalarsAndVectors); 163 164 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) 165 .legalForCartesianProduct(allFloatScalarsAndVectors, 166 allScalarsAndVectors); 167 168 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) 169 .legalFor(allIntScalarsAndVectors); 170 171 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct( 172 allIntScalarsAndVectors, allIntScalarsAndVectors); 173 174 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors); 175 176 getActionDefinitionsBuilder(G_BITCAST).legalIf(all( 177 typeInSet(0, allPtrsScalarsAndVectors), 178 typeInSet(1, allPtrsScalarsAndVectors), 179 LegalityPredicate(([=](const LegalityQuery &Query) { 180 return Query.Types[0].getSizeInBits() == Query.Types[1].getSizeInBits(); 181 })))); 182 183 getActionDefinitionsBuilder(G_IMPLICIT_DEF).alwaysLegal(); 184 185 getActionDefinitionsBuilder(G_INTTOPTR) 186 .legalForCartesianProduct(allPtrs, allIntScalars); 187 getActionDefinitionsBuilder(G_PTRTOINT) 188 .legalForCartesianProduct(allIntScalars, allPtrs); 189 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct( 190 allPtrs, allIntScalars); 191 192 // ST.canDirectlyComparePointers() for pointer args is supported in 193 // legalizeCustom(). 194 getActionDefinitionsBuilder(G_ICMP).customIf( 195 all(typeInSet(0, allBoolScalarsAndVectors), 196 typeInSet(1, allPtrsScalarsAndVectors))); 197 198 getActionDefinitionsBuilder(G_FCMP).legalIf( 199 all(typeInSet(0, allBoolScalarsAndVectors), 200 typeInSet(1, allFloatScalarsAndVectors))); 201 202 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, 203 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, 204 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, 205 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) 206 .legalForCartesianProduct(allIntScalars, allWritablePtrs); 207 208 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) 209 .legalForCartesianProduct(allFloatAndIntScalars, allWritablePtrs); 210 211 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); 212 // TODO: add proper legalization rules. 213 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); 214 215 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO}) 216 .alwaysLegal(); 217 218 // Extensions. 219 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) 220 .legalForCartesianProduct(allScalarsAndVectors); 221 222 // FP conversions. 223 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) 224 .legalForCartesianProduct(allFloatScalarsAndVectors); 225 226 // Pointer-handling. 227 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); 228 229 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 230 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); 231 232 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to 233 // tighten these requirements. Many of these math functions are only legal on 234 // specific bitwidths, so they are not selectable for 235 // allFloatScalarsAndVectors. 236 getActionDefinitionsBuilder({G_FPOW, 237 G_FEXP, 238 G_FEXP2, 239 G_FLOG, 240 G_FLOG2, 241 G_FLOG10, 242 G_FABS, 243 G_FMINNUM, 244 G_FMAXNUM, 245 G_FCEIL, 246 G_FCOS, 247 G_FSIN, 248 G_FSQRT, 249 G_FFLOOR, 250 G_FRINT, 251 G_FNEARBYINT, 252 G_INTRINSIC_ROUND, 253 G_INTRINSIC_TRUNC, 254 G_FMINIMUM, 255 G_FMAXIMUM, 256 G_INTRINSIC_ROUNDEVEN}) 257 .legalFor(allFloatScalarsAndVectors); 258 259 getActionDefinitionsBuilder(G_FCOPYSIGN) 260 .legalForCartesianProduct(allFloatScalarsAndVectors, 261 allFloatScalarsAndVectors); 262 263 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( 264 allFloatScalarsAndVectors, allIntScalarsAndVectors); 265 266 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 267 getActionDefinitionsBuilder( 268 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) 269 .legalForCartesianProduct(allIntScalarsAndVectors, 270 allIntScalarsAndVectors); 271 272 // Struct return types become a single scalar, so cannot easily legalize. 273 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal(); 274 } 275 276 getLegacyLegalizerInfo().computeTables(); 277 verify(*ST.getInstrInfo()); 278 } 279 280 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, 281 LegalizerHelper &Helper, 282 MachineRegisterInfo &MRI, 283 SPIRVGlobalRegistry *GR) { 284 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); 285 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF()); 286 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) 287 .addDef(ConvReg) 288 .addUse(Reg); 289 return ConvReg; 290 } 291 292 bool SPIRVLegalizerInfo::legalizeCustom( 293 LegalizerHelper &Helper, MachineInstr &MI, 294 LostDebugLocObserver &LocObserver) const { 295 auto Opc = MI.getOpcode(); 296 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 297 if (!isTypeFoldingSupported(Opc)) { 298 assert(Opc == TargetOpcode::G_ICMP); 299 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 300 auto &Op0 = MI.getOperand(2); 301 auto &Op1 = MI.getOperand(3); 302 Register Reg0 = Op0.getReg(); 303 Register Reg1 = Op1.getReg(); 304 CmpInst::Predicate Cond = 305 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 306 if ((!ST->canDirectlyComparePointers() || 307 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && 308 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { 309 LLT ConvT = LLT::scalar(ST->getPointerSize()); 310 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), 311 ST->getPointerSize()); 312 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); 313 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); 314 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); 315 } 316 return true; 317 } 318 // TODO: implement legalization for other opcodes. 319 return true; 320 } 321