1 //===- SPIRVISelLowering.cpp - SPIR-V DAG Lowering Impl ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the SPIRVTargetLowering class. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVISelLowering.h" 14 #include "SPIRV.h" 15 #include "SPIRVInstrInfo.h" 16 #include "SPIRVRegisterBankInfo.h" 17 #include "SPIRVRegisterInfo.h" 18 #include "SPIRVSubtarget.h" 19 #include "SPIRVTargetMachine.h" 20 #include "llvm/CodeGen/MachineInstrBuilder.h" 21 #include "llvm/CodeGen/MachineRegisterInfo.h" 22 #include "llvm/IR/IntrinsicsSPIRV.h" 23 24 #define DEBUG_TYPE "spirv-lower" 25 26 using namespace llvm; 27 28 unsigned SPIRVTargetLowering::getNumRegistersForCallingConv( 29 LLVMContext &Context, CallingConv::ID CC, EVT VT) const { 30 // This code avoids CallLowering fail inside getVectorTypeBreakdown 31 // on v3i1 arguments. Maybe we need to return 1 for all types. 32 // TODO: remove it once this case is supported by the default implementation. 33 if (VT.isVector() && VT.getVectorNumElements() == 3 && 34 (VT.getVectorElementType() == MVT::i1 || 35 VT.getVectorElementType() == MVT::i8)) 36 return 1; 37 if (!VT.isVector() && VT.isInteger() && VT.getSizeInBits() <= 64) 38 return 1; 39 return getNumRegisters(Context, VT); 40 } 41 42 MVT SPIRVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context, 43 CallingConv::ID CC, 44 EVT VT) const { 45 // This code avoids CallLowering fail inside getVectorTypeBreakdown 46 // on v3i1 arguments. Maybe we need to return i32 for all types. 47 // TODO: remove it once this case is supported by the default implementation. 48 if (VT.isVector() && VT.getVectorNumElements() == 3) { 49 if (VT.getVectorElementType() == MVT::i1) 50 return MVT::v4i1; 51 else if (VT.getVectorElementType() == MVT::i8) 52 return MVT::v4i8; 53 } 54 return getRegisterType(Context, VT); 55 } 56 57 bool SPIRVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info, 58 const CallInst &I, 59 MachineFunction &MF, 60 unsigned Intrinsic) const { 61 unsigned AlignIdx = 3; 62 switch (Intrinsic) { 63 case Intrinsic::spv_load: 64 AlignIdx = 2; 65 [[fallthrough]]; 66 case Intrinsic::spv_store: { 67 if (I.getNumOperands() >= AlignIdx + 1) { 68 auto *AlignOp = cast<ConstantInt>(I.getOperand(AlignIdx)); 69 Info.align = Align(AlignOp->getZExtValue()); 70 } 71 Info.flags = static_cast<MachineMemOperand::Flags>( 72 cast<ConstantInt>(I.getOperand(AlignIdx - 1))->getZExtValue()); 73 Info.memVT = MVT::i64; 74 // TODO: take into account opaque pointers (don't use getElementType). 75 // MVT::getVT(PtrTy->getElementType()); 76 return true; 77 break; 78 } 79 default: 80 break; 81 } 82 return false; 83 } 84 85 std::pair<unsigned, const TargetRegisterClass *> 86 SPIRVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI, 87 StringRef Constraint, 88 MVT VT) const { 89 const TargetRegisterClass *RC = nullptr; 90 if (Constraint.starts_with("{")) 91 return std::make_pair(0u, RC); 92 93 if (VT.isFloatingPoint()) 94 RC = VT.isVector() ? &SPIRV::vfIDRegClass 95 : (VT.getScalarSizeInBits() > 32 ? &SPIRV::fID64RegClass 96 : &SPIRV::fIDRegClass); 97 else if (VT.isInteger()) 98 RC = VT.isVector() ? &SPIRV::vIDRegClass 99 : (VT.getScalarSizeInBits() > 32 ? &SPIRV::ID64RegClass 100 : &SPIRV::IDRegClass); 101 else 102 RC = &SPIRV::IDRegClass; 103 104 return std::make_pair(0u, RC); 105 } 106 107 inline Register getTypeReg(MachineRegisterInfo *MRI, Register OpReg) { 108 SPIRVType *TypeInst = MRI->getVRegDef(OpReg); 109 return TypeInst && TypeInst->getOpcode() == SPIRV::OpFunctionParameter 110 ? TypeInst->getOperand(1).getReg() 111 : OpReg; 112 } 113 114 static void doInsertBitcast(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, 115 SPIRVGlobalRegistry &GR, MachineInstr &I, 116 Register OpReg, unsigned OpIdx, 117 SPIRVType *NewPtrType) { 118 Register NewReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 119 MachineIRBuilder MIB(I); 120 bool Res = MIB.buildInstr(SPIRV::OpBitcast) 121 .addDef(NewReg) 122 .addUse(GR.getSPIRVTypeID(NewPtrType)) 123 .addUse(OpReg) 124 .constrainAllUses(*STI.getInstrInfo(), *STI.getRegisterInfo(), 125 *STI.getRegBankInfo()); 126 if (!Res) 127 report_fatal_error("insert validation bitcast: cannot constrain all uses"); 128 MRI->setRegClass(NewReg, &SPIRV::IDRegClass); 129 GR.assignSPIRVTypeToVReg(NewPtrType, NewReg, MIB.getMF()); 130 I.getOperand(OpIdx).setReg(NewReg); 131 } 132 133 static SPIRVType *createNewPtrType(SPIRVGlobalRegistry &GR, MachineInstr &I, 134 SPIRVType *OpType, bool ReuseType, 135 bool EmitIR, SPIRVType *ResType, 136 const Type *ResTy) { 137 SPIRV::StorageClass::StorageClass SC = 138 static_cast<SPIRV::StorageClass::StorageClass>( 139 OpType->getOperand(1).getImm()); 140 MachineIRBuilder MIB(I); 141 SPIRVType *NewBaseType = 142 ReuseType ? ResType 143 : GR.getOrCreateSPIRVType( 144 ResTy, MIB, SPIRV::AccessQualifier::ReadWrite, EmitIR); 145 return GR.getOrCreateSPIRVPointerType(NewBaseType, MIB, SC); 146 } 147 148 // Insert a bitcast before the instruction to keep SPIR-V code valid 149 // when there is a type mismatch between results and operand types. 150 static void validatePtrTypes(const SPIRVSubtarget &STI, 151 MachineRegisterInfo *MRI, SPIRVGlobalRegistry &GR, 152 MachineInstr &I, unsigned OpIdx, 153 SPIRVType *ResType, const Type *ResTy = nullptr) { 154 // Get operand type 155 MachineFunction *MF = I.getParent()->getParent(); 156 Register OpReg = I.getOperand(OpIdx).getReg(); 157 Register OpTypeReg = getTypeReg(MRI, OpReg); 158 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 159 if (!ResType || !OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 160 return; 161 // Get operand's pointee type 162 Register ElemTypeReg = OpType->getOperand(2).getReg(); 163 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(ElemTypeReg, MF); 164 if (!ElemType) 165 return; 166 // Check if we need a bitcast to make a statement valid 167 bool IsSameMF = MF == ResType->getParent()->getParent(); 168 bool IsEqualTypes = IsSameMF ? ElemType == ResType 169 : GR.getTypeForSPIRVType(ElemType) == ResTy; 170 if (IsEqualTypes) 171 return; 172 // There is a type mismatch between results and operand types 173 // and we insert a bitcast before the instruction to keep SPIR-V code valid 174 SPIRVType *NewPtrType = 175 createNewPtrType(GR, I, OpType, IsSameMF, false, ResType, ResTy); 176 if (!GR.isBitcastCompatible(NewPtrType, OpType)) 177 report_fatal_error( 178 "insert validation bitcast: incompatible result and operand types"); 179 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 180 } 181 182 // Insert a bitcast before OpGroupWaitEvents if the last argument is a pointer 183 // that doesn't point to OpTypeEvent. 184 static void validateGroupWaitEventsPtr(const SPIRVSubtarget &STI, 185 MachineRegisterInfo *MRI, 186 SPIRVGlobalRegistry &GR, 187 MachineInstr &I) { 188 constexpr unsigned OpIdx = 2; 189 MachineFunction *MF = I.getParent()->getParent(); 190 Register OpReg = I.getOperand(OpIdx).getReg(); 191 Register OpTypeReg = getTypeReg(MRI, OpReg); 192 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 193 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 194 return; 195 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); 196 if (!ElemType || ElemType->getOpcode() == SPIRV::OpTypeEvent) 197 return; 198 // Insert a bitcast before the instruction to keep SPIR-V code valid. 199 LLVMContext &Context = MF->getFunction().getContext(); 200 SPIRVType *NewPtrType = 201 createNewPtrType(GR, I, OpType, false, true, nullptr, 202 TargetExtType::get(Context, "spirv.Event")); 203 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 204 } 205 206 static void validateGroupAsyncCopyPtr(const SPIRVSubtarget &STI, 207 MachineRegisterInfo *MRI, 208 SPIRVGlobalRegistry &GR, MachineInstr &I, 209 unsigned OpIdx) { 210 MachineFunction *MF = I.getParent()->getParent(); 211 Register OpReg = I.getOperand(OpIdx).getReg(); 212 Register OpTypeReg = getTypeReg(MRI, OpReg); 213 SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpTypeReg, MF); 214 if (!OpType || OpType->getOpcode() != SPIRV::OpTypePointer) 215 return; 216 SPIRVType *ElemType = GR.getSPIRVTypeForVReg(OpType->getOperand(2).getReg()); 217 if (!ElemType || ElemType->getOpcode() != SPIRV::OpTypeStruct || 218 ElemType->getNumOperands() != 2) 219 return; 220 // It's a structure-wrapper around another type with a single member field. 221 SPIRVType *MemberType = 222 GR.getSPIRVTypeForVReg(ElemType->getOperand(1).getReg()); 223 if (!MemberType) 224 return; 225 unsigned MemberTypeOp = MemberType->getOpcode(); 226 if (MemberTypeOp != SPIRV::OpTypeVector && MemberTypeOp != SPIRV::OpTypeInt && 227 MemberTypeOp != SPIRV::OpTypeFloat && MemberTypeOp != SPIRV::OpTypeBool) 228 return; 229 // It's a structure-wrapper around a valid type. Insert a bitcast before the 230 // instruction to keep SPIR-V code valid. 231 SPIRV::StorageClass::StorageClass SC = 232 static_cast<SPIRV::StorageClass::StorageClass>( 233 OpType->getOperand(1).getImm()); 234 MachineIRBuilder MIB(I); 235 SPIRVType *NewPtrType = GR.getOrCreateSPIRVPointerType(MemberType, MIB, SC); 236 doInsertBitcast(STI, MRI, GR, I, OpReg, OpIdx, NewPtrType); 237 } 238 239 // Insert a bitcast before the function call instruction to keep SPIR-V code 240 // valid when there is a type mismatch between actual and expected types of an 241 // argument: 242 // %formal = OpFunctionParameter %formal_type 243 // ... 244 // %res = OpFunctionCall %ty %fun %actual ... 245 // implies that %actual is of %formal_type, and in case of opaque pointers. 246 // We may need to insert a bitcast to ensure this. 247 void validateFunCallMachineDef(const SPIRVSubtarget &STI, 248 MachineRegisterInfo *DefMRI, 249 MachineRegisterInfo *CallMRI, 250 SPIRVGlobalRegistry &GR, MachineInstr &FunCall, 251 MachineInstr *FunDef) { 252 if (FunDef->getOpcode() != SPIRV::OpFunction) 253 return; 254 unsigned OpIdx = 3; 255 for (FunDef = FunDef->getNextNode(); 256 FunDef && FunDef->getOpcode() == SPIRV::OpFunctionParameter && 257 OpIdx < FunCall.getNumOperands(); 258 FunDef = FunDef->getNextNode(), OpIdx++) { 259 SPIRVType *DefPtrType = DefMRI->getVRegDef(FunDef->getOperand(1).getReg()); 260 SPIRVType *DefElemType = 261 DefPtrType && DefPtrType->getOpcode() == SPIRV::OpTypePointer 262 ? GR.getSPIRVTypeForVReg(DefPtrType->getOperand(2).getReg(), 263 DefPtrType->getParent()->getParent()) 264 : nullptr; 265 if (DefElemType) { 266 const Type *DefElemTy = GR.getTypeForSPIRVType(DefElemType); 267 // validatePtrTypes() works in the context if the call site 268 // When we process historical records about forward calls 269 // we need to switch context to the (forward) call site and 270 // then restore it back to the current machine function. 271 MachineFunction *CurMF = 272 GR.setCurrentFunc(*FunCall.getParent()->getParent()); 273 validatePtrTypes(STI, CallMRI, GR, FunCall, OpIdx, DefElemType, 274 DefElemTy); 275 GR.setCurrentFunc(*CurMF); 276 } 277 } 278 } 279 280 // Ensure there is no mismatch between actual and expected arg types: calls 281 // with a processed definition. Return Function pointer if it's a forward 282 // call (ahead of definition), and nullptr otherwise. 283 const Function *validateFunCall(const SPIRVSubtarget &STI, 284 MachineRegisterInfo *CallMRI, 285 SPIRVGlobalRegistry &GR, 286 MachineInstr &FunCall) { 287 const GlobalValue *GV = FunCall.getOperand(2).getGlobal(); 288 const Function *F = dyn_cast<Function>(GV); 289 MachineInstr *FunDef = 290 const_cast<MachineInstr *>(GR.getFunctionDefinition(F)); 291 if (!FunDef) 292 return F; 293 MachineRegisterInfo *DefMRI = &FunDef->getParent()->getParent()->getRegInfo(); 294 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, FunCall, FunDef); 295 return nullptr; 296 } 297 298 // Ensure there is no mismatch between actual and expected arg types: calls 299 // ahead of a processed definition. 300 void validateForwardCalls(const SPIRVSubtarget &STI, 301 MachineRegisterInfo *DefMRI, SPIRVGlobalRegistry &GR, 302 MachineInstr &FunDef) { 303 const Function *F = GR.getFunctionByDefinition(&FunDef); 304 if (SmallPtrSet<MachineInstr *, 8> *FwdCalls = GR.getForwardCalls(F)) 305 for (MachineInstr *FunCall : *FwdCalls) { 306 MachineRegisterInfo *CallMRI = 307 &FunCall->getParent()->getParent()->getRegInfo(); 308 validateFunCallMachineDef(STI, DefMRI, CallMRI, GR, *FunCall, &FunDef); 309 } 310 } 311 312 // Validation of an access chain. 313 void validateAccessChain(const SPIRVSubtarget &STI, MachineRegisterInfo *MRI, 314 SPIRVGlobalRegistry &GR, MachineInstr &I) { 315 SPIRVType *BaseTypeInst = GR.getSPIRVTypeForVReg(I.getOperand(0).getReg()); 316 if (BaseTypeInst && BaseTypeInst->getOpcode() == SPIRV::OpTypePointer) { 317 SPIRVType *BaseElemType = 318 GR.getSPIRVTypeForVReg(BaseTypeInst->getOperand(2).getReg()); 319 validatePtrTypes(STI, MRI, GR, I, 2, BaseElemType); 320 } 321 } 322 323 // TODO: the logic of inserting additional bitcast's is to be moved 324 // to pre-IRTranslation passes eventually 325 void SPIRVTargetLowering::finalizeLowering(MachineFunction &MF) const { 326 // finalizeLowering() is called twice (see GlobalISel/InstructionSelect.cpp) 327 // We'd like to avoid the needless second processing pass. 328 if (ProcessedMF.find(&MF) != ProcessedMF.end()) 329 return; 330 331 MachineRegisterInfo *MRI = &MF.getRegInfo(); 332 SPIRVGlobalRegistry &GR = *STI.getSPIRVGlobalRegistry(); 333 GR.setCurrentFunc(MF); 334 for (MachineFunction::iterator I = MF.begin(), E = MF.end(); I != E; ++I) { 335 MachineBasicBlock *MBB = &*I; 336 for (MachineBasicBlock::iterator MBBI = MBB->begin(), MBBE = MBB->end(); 337 MBBI != MBBE;) { 338 MachineInstr &MI = *MBBI++; 339 switch (MI.getOpcode()) { 340 case SPIRV::OpAtomicLoad: 341 case SPIRV::OpAtomicExchange: 342 case SPIRV::OpAtomicCompareExchange: 343 case SPIRV::OpAtomicCompareExchangeWeak: 344 case SPIRV::OpAtomicIIncrement: 345 case SPIRV::OpAtomicIDecrement: 346 case SPIRV::OpAtomicIAdd: 347 case SPIRV::OpAtomicISub: 348 case SPIRV::OpAtomicSMin: 349 case SPIRV::OpAtomicUMin: 350 case SPIRV::OpAtomicSMax: 351 case SPIRV::OpAtomicUMax: 352 case SPIRV::OpAtomicAnd: 353 case SPIRV::OpAtomicOr: 354 case SPIRV::OpAtomicXor: 355 // for the above listed instructions 356 // OpAtomicXXX <ResType>, ptr %Op, ... 357 // implies that %Op is a pointer to <ResType> 358 case SPIRV::OpLoad: 359 // OpLoad <ResType>, ptr %Op implies that %Op is a pointer to <ResType> 360 validatePtrTypes(STI, MRI, GR, MI, 2, 361 GR.getSPIRVTypeForVReg(MI.getOperand(0).getReg())); 362 break; 363 case SPIRV::OpAtomicStore: 364 // OpAtomicStore ptr %Op, <Scope>, <Mem>, <Obj> 365 // implies that %Op points to the <Obj>'s type 366 validatePtrTypes(STI, MRI, GR, MI, 0, 367 GR.getSPIRVTypeForVReg(MI.getOperand(3).getReg())); 368 break; 369 case SPIRV::OpStore: 370 // OpStore ptr %Op, <Obj> implies that %Op points to the <Obj>'s type 371 validatePtrTypes(STI, MRI, GR, MI, 0, 372 GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg())); 373 break; 374 case SPIRV::OpPtrCastToGeneric: 375 case SPIRV::OpGenericCastToPtr: 376 validateAccessChain(STI, MRI, GR, MI); 377 break; 378 case SPIRV::OpInBoundsPtrAccessChain: 379 if (MI.getNumOperands() == 4) 380 validateAccessChain(STI, MRI, GR, MI); 381 break; 382 383 case SPIRV::OpFunctionCall: 384 // ensure there is no mismatch between actual and expected arg types: 385 // calls with a processed definition 386 if (MI.getNumOperands() > 3) 387 if (const Function *F = validateFunCall(STI, MRI, GR, MI)) 388 GR.addForwardCall(F, &MI); 389 break; 390 case SPIRV::OpFunction: 391 // ensure there is no mismatch between actual and expected arg types: 392 // calls ahead of a processed definition 393 validateForwardCalls(STI, MRI, GR, MI); 394 break; 395 396 // ensure that LLVM IR bitwise instructions result in logical SPIR-V 397 // instructions when applied to bool type 398 case SPIRV::OpBitwiseOrS: 399 case SPIRV::OpBitwiseOrV: 400 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 401 SPIRV::OpTypeBool)) 402 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalOr)); 403 break; 404 case SPIRV::OpBitwiseAndS: 405 case SPIRV::OpBitwiseAndV: 406 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 407 SPIRV::OpTypeBool)) 408 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalAnd)); 409 break; 410 case SPIRV::OpBitwiseXorS: 411 case SPIRV::OpBitwiseXorV: 412 if (GR.isScalarOrVectorOfType(MI.getOperand(1).getReg(), 413 SPIRV::OpTypeBool)) 414 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpLogicalNotEqual)); 415 break; 416 case SPIRV::OpGroupAsyncCopy: 417 validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 3); 418 validateGroupAsyncCopyPtr(STI, MRI, GR, MI, 4); 419 break; 420 case SPIRV::OpGroupWaitEvents: 421 // OpGroupWaitEvents ..., ..., <pointer to OpTypeEvent> 422 validateGroupWaitEventsPtr(STI, MRI, GR, MI); 423 break; 424 case SPIRV::OpConstantI: { 425 SPIRVType *Type = GR.getSPIRVTypeForVReg(MI.getOperand(1).getReg()); 426 if (Type->getOpcode() != SPIRV::OpTypeInt && MI.getOperand(2).isImm() && 427 MI.getOperand(2).getImm() == 0) { 428 // Validate the null constant of a target extension type 429 MI.setDesc(STI.getInstrInfo()->get(SPIRV::OpConstantNull)); 430 for (unsigned i = MI.getNumOperands() - 1; i > 1; --i) 431 MI.removeOperand(i); 432 } 433 } break; 434 } 435 } 436 } 437 ProcessedMF.insert(&MF); 438 TargetLowering::finalizeLowering(MF); 439 } 440