1 //===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the targeting of the Machinelegalizer class for SPIR-V. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVLegalizerInfo.h" 14 #include "SPIRV.h" 15 #include "SPIRVGlobalRegistry.h" 16 #include "SPIRVSubtarget.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 19 #include "llvm/CodeGen/MachineInstr.h" 20 #include "llvm/CodeGen/MachineRegisterInfo.h" 21 #include "llvm/CodeGen/TargetOpcodes.h" 22 23 using namespace llvm; 24 using namespace llvm::LegalizeActions; 25 using namespace llvm::LegalityPredicates; 26 27 static const std::set<unsigned> TypeFoldingSupportingOpcs = { 28 TargetOpcode::G_ADD, 29 TargetOpcode::G_FADD, 30 TargetOpcode::G_SUB, 31 TargetOpcode::G_FSUB, 32 TargetOpcode::G_MUL, 33 TargetOpcode::G_FMUL, 34 TargetOpcode::G_SDIV, 35 TargetOpcode::G_UDIV, 36 TargetOpcode::G_FDIV, 37 TargetOpcode::G_SREM, 38 TargetOpcode::G_UREM, 39 TargetOpcode::G_FREM, 40 TargetOpcode::G_FNEG, 41 TargetOpcode::G_CONSTANT, 42 TargetOpcode::G_FCONSTANT, 43 TargetOpcode::G_AND, 44 TargetOpcode::G_OR, 45 TargetOpcode::G_XOR, 46 TargetOpcode::G_SHL, 47 TargetOpcode::G_ASHR, 48 TargetOpcode::G_LSHR, 49 TargetOpcode::G_SELECT, 50 TargetOpcode::G_EXTRACT_VECTOR_ELT, 51 }; 52 53 bool isTypeFoldingSupported(unsigned Opcode) { 54 return TypeFoldingSupportingOpcs.count(Opcode) > 0; 55 } 56 57 SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) { 58 using namespace TargetOpcode; 59 60 this->ST = &ST; 61 GR = ST.getSPIRVGlobalRegistry(); 62 63 const LLT s1 = LLT::scalar(1); 64 const LLT s8 = LLT::scalar(8); 65 const LLT s16 = LLT::scalar(16); 66 const LLT s32 = LLT::scalar(32); 67 const LLT s64 = LLT::scalar(64); 68 69 const LLT v16s64 = LLT::fixed_vector(16, 64); 70 const LLT v16s32 = LLT::fixed_vector(16, 32); 71 const LLT v16s16 = LLT::fixed_vector(16, 16); 72 const LLT v16s8 = LLT::fixed_vector(16, 8); 73 const LLT v16s1 = LLT::fixed_vector(16, 1); 74 75 const LLT v8s64 = LLT::fixed_vector(8, 64); 76 const LLT v8s32 = LLT::fixed_vector(8, 32); 77 const LLT v8s16 = LLT::fixed_vector(8, 16); 78 const LLT v8s8 = LLT::fixed_vector(8, 8); 79 const LLT v8s1 = LLT::fixed_vector(8, 1); 80 81 const LLT v4s64 = LLT::fixed_vector(4, 64); 82 const LLT v4s32 = LLT::fixed_vector(4, 32); 83 const LLT v4s16 = LLT::fixed_vector(4, 16); 84 const LLT v4s8 = LLT::fixed_vector(4, 8); 85 const LLT v4s1 = LLT::fixed_vector(4, 1); 86 87 const LLT v3s64 = LLT::fixed_vector(3, 64); 88 const LLT v3s32 = LLT::fixed_vector(3, 32); 89 const LLT v3s16 = LLT::fixed_vector(3, 16); 90 const LLT v3s8 = LLT::fixed_vector(3, 8); 91 const LLT v3s1 = LLT::fixed_vector(3, 1); 92 93 const LLT v2s64 = LLT::fixed_vector(2, 64); 94 const LLT v2s32 = LLT::fixed_vector(2, 32); 95 const LLT v2s16 = LLT::fixed_vector(2, 16); 96 const LLT v2s8 = LLT::fixed_vector(2, 8); 97 const LLT v2s1 = LLT::fixed_vector(2, 1); 98 99 const unsigned PSize = ST.getPointerSize(); 100 const LLT p0 = LLT::pointer(0, PSize); // Function 101 const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup 102 const LLT p2 = LLT::pointer(2, PSize); // UniformConstant 103 const LLT p3 = LLT::pointer(3, PSize); // Workgroup 104 const LLT p4 = LLT::pointer(4, PSize); // Generic 105 const LLT p5 = 106 LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device) 107 const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host) 108 109 // TODO: remove copy-pasting here by using concatenation in some way. 110 auto allPtrsScalarsAndVectors = { 111 p0, p1, p2, p3, p4, p5, p6, s1, s8, s16, 112 s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16, 113 v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16, 114 v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 115 116 auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, 117 v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, 118 v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, 119 v16s8, v16s16, v16s32, v16s64}; 120 121 auto allScalarsAndVectors = { 122 s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, 123 v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, 124 v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64}; 125 126 auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16, 127 v2s32, v2s64, v3s8, v3s16, v3s32, v3s64, 128 v4s8, v4s16, v4s32, v4s64, v8s8, v8s16, 129 v8s32, v8s64, v16s8, v16s16, v16s32, v16s64}; 130 131 auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1}; 132 133 auto allIntScalars = {s8, s16, s32, s64}; 134 135 auto allFloatScalars = {s16, s32, s64}; 136 137 auto allFloatScalarsAndVectors = { 138 s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64, 139 v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64}; 140 141 auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1, 142 p2, p3, p4, p5, p6}; 143 144 auto allPtrs = {p0, p1, p2, p3, p4, p5, p6}; 145 auto allWritablePtrs = {p0, p1, p3, p4, p5, p6}; 146 147 for (auto Opc : TypeFoldingSupportingOpcs) 148 getActionDefinitionsBuilder(Opc).custom(); 149 150 getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal(); 151 152 // TODO: add proper rules for vectors legalization. 153 getActionDefinitionsBuilder( 154 {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR}) 155 .alwaysLegal(); 156 157 // Vector Reduction Operations 158 getActionDefinitionsBuilder( 159 {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX, 160 G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN, 161 G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM, 162 G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR}) 163 .legalFor(allVectors) 164 .scalarize(1) 165 .lower(); 166 167 getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL}) 168 .scalarize(2) 169 .lower(); 170 171 // Merge/Unmerge 172 // TODO: add proper legalization rules. 173 getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal(); 174 175 getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE}) 176 .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs))); 177 178 getActionDefinitionsBuilder(G_MEMSET).legalIf( 179 all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars))); 180 181 getActionDefinitionsBuilder(G_ADDRSPACE_CAST) 182 .legalForCartesianProduct(allPtrs, allPtrs); 183 184 getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs)); 185 186 getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors); 187 188 getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors); 189 190 getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI}) 191 .legalForCartesianProduct(allIntScalarsAndVectors, 192 allFloatScalarsAndVectors); 193 194 getActionDefinitionsBuilder({G_SITOFP, G_UITOFP}) 195 .legalForCartesianProduct(allFloatScalarsAndVectors, 196 allScalarsAndVectors); 197 198 getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS}) 199 .legalFor(allIntScalarsAndVectors); 200 201 getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct( 202 allIntScalarsAndVectors, allIntScalarsAndVectors); 203 204 getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors); 205 206 getActionDefinitionsBuilder(G_BITCAST).legalIf( 207 all(typeInSet(0, allPtrsScalarsAndVectors), 208 typeInSet(1, allPtrsScalarsAndVectors))); 209 210 getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal(); 211 212 getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal(); 213 214 getActionDefinitionsBuilder(G_INTTOPTR) 215 .legalForCartesianProduct(allPtrs, allIntScalars); 216 getActionDefinitionsBuilder(G_PTRTOINT) 217 .legalForCartesianProduct(allIntScalars, allPtrs); 218 getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct( 219 allPtrs, allIntScalars); 220 221 // ST.canDirectlyComparePointers() for pointer args is supported in 222 // legalizeCustom(). 223 getActionDefinitionsBuilder(G_ICMP).customIf( 224 all(typeInSet(0, allBoolScalarsAndVectors), 225 typeInSet(1, allPtrsScalarsAndVectors))); 226 227 getActionDefinitionsBuilder(G_FCMP).legalIf( 228 all(typeInSet(0, allBoolScalarsAndVectors), 229 typeInSet(1, allFloatScalarsAndVectors))); 230 231 getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND, 232 G_ATOMICRMW_MAX, G_ATOMICRMW_MIN, 233 G_ATOMICRMW_SUB, G_ATOMICRMW_XOR, 234 G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN}) 235 .legalForCartesianProduct(allIntScalars, allWritablePtrs); 236 237 getActionDefinitionsBuilder( 238 {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX}) 239 .legalForCartesianProduct(allFloatScalars, allWritablePtrs); 240 241 getActionDefinitionsBuilder(G_ATOMICRMW_XCHG) 242 .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs); 243 244 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower(); 245 // TODO: add proper legalization rules. 246 getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal(); 247 248 getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO}) 249 .alwaysLegal(); 250 251 // Extensions. 252 getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT}) 253 .legalForCartesianProduct(allScalarsAndVectors); 254 255 // FP conversions. 256 getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT}) 257 .legalForCartesianProduct(allFloatScalarsAndVectors); 258 259 // Pointer-handling. 260 getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0}); 261 262 // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32. 263 getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32}); 264 265 // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to 266 // tighten these requirements. Many of these math functions are only legal on 267 // specific bitwidths, so they are not selectable for 268 // allFloatScalarsAndVectors. 269 getActionDefinitionsBuilder({G_FPOW, 270 G_FEXP, 271 G_FEXP2, 272 G_FLOG, 273 G_FLOG2, 274 G_FLOG10, 275 G_FABS, 276 G_FMINNUM, 277 G_FMAXNUM, 278 G_FCEIL, 279 G_FCOS, 280 G_FSIN, 281 G_FTAN, 282 G_FACOS, 283 G_FASIN, 284 G_FATAN, 285 G_FCOSH, 286 G_FSINH, 287 G_FTANH, 288 G_FSQRT, 289 G_FFLOOR, 290 G_FRINT, 291 G_FNEARBYINT, 292 G_INTRINSIC_ROUND, 293 G_INTRINSIC_TRUNC, 294 G_FMINIMUM, 295 G_FMAXIMUM, 296 G_INTRINSIC_ROUNDEVEN}) 297 .legalFor(allFloatScalarsAndVectors); 298 299 getActionDefinitionsBuilder(G_FCOPYSIGN) 300 .legalForCartesianProduct(allFloatScalarsAndVectors, 301 allFloatScalarsAndVectors); 302 303 getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct( 304 allFloatScalarsAndVectors, allIntScalarsAndVectors); 305 306 if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) { 307 getActionDefinitionsBuilder( 308 {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF}) 309 .legalForCartesianProduct(allIntScalarsAndVectors, 310 allIntScalarsAndVectors); 311 312 // Struct return types become a single scalar, so cannot easily legalize. 313 getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal(); 314 315 // supported saturation arithmetic 316 getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT}) 317 .legalFor(allIntScalarsAndVectors); 318 } 319 320 getLegacyLegalizerInfo().computeTables(); 321 verify(*ST.getInstrInfo()); 322 } 323 324 static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType, 325 LegalizerHelper &Helper, 326 MachineRegisterInfo &MRI, 327 SPIRVGlobalRegistry *GR) { 328 Register ConvReg = MRI.createGenericVirtualRegister(ConvTy); 329 GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF()); 330 Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT) 331 .addDef(ConvReg) 332 .addUse(Reg); 333 return ConvReg; 334 } 335 336 bool SPIRVLegalizerInfo::legalizeCustom( 337 LegalizerHelper &Helper, MachineInstr &MI, 338 LostDebugLocObserver &LocObserver) const { 339 auto Opc = MI.getOpcode(); 340 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); 341 if (!isTypeFoldingSupported(Opc)) { 342 assert(Opc == TargetOpcode::G_ICMP); 343 assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 344 auto &Op0 = MI.getOperand(2); 345 auto &Op1 = MI.getOperand(3); 346 Register Reg0 = Op0.getReg(); 347 Register Reg1 = Op1.getReg(); 348 CmpInst::Predicate Cond = 349 static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 350 if ((!ST->canDirectlyComparePointers() || 351 (Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) && 352 MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) { 353 LLT ConvT = LLT::scalar(ST->getPointerSize()); 354 Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(), 355 ST->getPointerSize()); 356 SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder); 357 Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR)); 358 Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR)); 359 } 360 return true; 361 } 362 // TODO: implement legalization for other opcodes. 363 return true; 364 } 365