1 //===--- SPIRVUtils.cpp ---- SPIR-V Utility Functions -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains miscellaneous utility functions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "SPIRVUtils.h" 14 #include "MCTargetDesc/SPIRVBaseInfo.h" 15 #include "SPIRV.h" 16 #include "SPIRVInstrInfo.h" 17 #include "SPIRVSubtarget.h" 18 #include "llvm/ADT/StringRef.h" 19 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 21 #include "llvm/CodeGen/MachineInstr.h" 22 #include "llvm/CodeGen/MachineInstrBuilder.h" 23 #include "llvm/Demangle/Demangle.h" 24 #include "llvm/IR/IntrinsicsSPIRV.h" 25 26 namespace llvm { 27 28 // The following functions are used to add these string literals as a series of 29 // 32-bit integer operands with the correct format, and unpack them if necessary 30 // when making string comparisons in compiler passes. 31 // SPIR-V requires null-terminated UTF-8 strings padded to 32-bit alignment. 32 static uint32_t convertCharsToWord(const StringRef &Str, unsigned i) { 33 uint32_t Word = 0u; // Build up this 32-bit word from 4 8-bit chars. 34 for (unsigned WordIndex = 0; WordIndex < 4; ++WordIndex) { 35 unsigned StrIndex = i + WordIndex; 36 uint8_t CharToAdd = 0; // Initilize char as padding/null. 37 if (StrIndex < Str.size()) { // If it's within the string, get a real char. 38 CharToAdd = Str[StrIndex]; 39 } 40 Word |= (CharToAdd << (WordIndex * 8)); 41 } 42 return Word; 43 } 44 45 // Get length including padding and null terminator. 46 static size_t getPaddedLen(const StringRef &Str) { 47 const size_t Len = Str.size() + 1; 48 return (Len % 4 == 0) ? Len : Len + (4 - (Len % 4)); 49 } 50 51 void addStringImm(const StringRef &Str, MCInst &Inst) { 52 const size_t PaddedLen = getPaddedLen(Str); 53 for (unsigned i = 0; i < PaddedLen; i += 4) { 54 // Add an operand for the 32-bits of chars or padding. 55 Inst.addOperand(MCOperand::createImm(convertCharsToWord(Str, i))); 56 } 57 } 58 59 void addStringImm(const StringRef &Str, MachineInstrBuilder &MIB) { 60 const size_t PaddedLen = getPaddedLen(Str); 61 for (unsigned i = 0; i < PaddedLen; i += 4) { 62 // Add an operand for the 32-bits of chars or padding. 63 MIB.addImm(convertCharsToWord(Str, i)); 64 } 65 } 66 67 void addStringImm(const StringRef &Str, IRBuilder<> &B, 68 std::vector<Value *> &Args) { 69 const size_t PaddedLen = getPaddedLen(Str); 70 for (unsigned i = 0; i < PaddedLen; i += 4) { 71 // Add a vector element for the 32-bits of chars or padding. 72 Args.push_back(B.getInt32(convertCharsToWord(Str, i))); 73 } 74 } 75 76 std::string getStringImm(const MachineInstr &MI, unsigned StartIndex) { 77 return getSPIRVStringOperand(MI, StartIndex); 78 } 79 80 void addNumImm(const APInt &Imm, MachineInstrBuilder &MIB) { 81 const auto Bitwidth = Imm.getBitWidth(); 82 if (Bitwidth == 1) 83 return; // Already handled 84 else if (Bitwidth <= 32) { 85 MIB.addImm(Imm.getZExtValue()); 86 // Asm Printer needs this info to print floating-type correctly 87 if (Bitwidth == 16) 88 MIB.getInstr()->setAsmPrinterFlag(SPIRV::ASM_PRINTER_WIDTH16); 89 return; 90 } else if (Bitwidth <= 64) { 91 uint64_t FullImm = Imm.getZExtValue(); 92 uint32_t LowBits = FullImm & 0xffffffff; 93 uint32_t HighBits = (FullImm >> 32) & 0xffffffff; 94 MIB.addImm(LowBits).addImm(HighBits); 95 return; 96 } 97 report_fatal_error("Unsupported constant bitwidth"); 98 } 99 100 void buildOpName(Register Target, const StringRef &Name, 101 MachineIRBuilder &MIRBuilder) { 102 if (!Name.empty()) { 103 auto MIB = MIRBuilder.buildInstr(SPIRV::OpName).addUse(Target); 104 addStringImm(Name, MIB); 105 } 106 } 107 108 static void finishBuildOpDecorate(MachineInstrBuilder &MIB, 109 const std::vector<uint32_t> &DecArgs, 110 StringRef StrImm) { 111 if (!StrImm.empty()) 112 addStringImm(StrImm, MIB); 113 for (const auto &DecArg : DecArgs) 114 MIB.addImm(DecArg); 115 } 116 117 void buildOpDecorate(Register Reg, MachineIRBuilder &MIRBuilder, 118 SPIRV::Decoration::Decoration Dec, 119 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 120 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 121 .addUse(Reg) 122 .addImm(static_cast<uint32_t>(Dec)); 123 finishBuildOpDecorate(MIB, DecArgs, StrImm); 124 } 125 126 void buildOpDecorate(Register Reg, MachineInstr &I, const SPIRVInstrInfo &TII, 127 SPIRV::Decoration::Decoration Dec, 128 const std::vector<uint32_t> &DecArgs, StringRef StrImm) { 129 MachineBasicBlock &MBB = *I.getParent(); 130 auto MIB = BuildMI(MBB, I, I.getDebugLoc(), TII.get(SPIRV::OpDecorate)) 131 .addUse(Reg) 132 .addImm(static_cast<uint32_t>(Dec)); 133 finishBuildOpDecorate(MIB, DecArgs, StrImm); 134 } 135 136 void buildOpSpirvDecorations(Register Reg, MachineIRBuilder &MIRBuilder, 137 const MDNode *GVarMD) { 138 for (unsigned I = 0, E = GVarMD->getNumOperands(); I != E; ++I) { 139 auto *OpMD = dyn_cast<MDNode>(GVarMD->getOperand(I)); 140 if (!OpMD) 141 report_fatal_error("Invalid decoration"); 142 if (OpMD->getNumOperands() == 0) 143 report_fatal_error("Expect operand(s) of the decoration"); 144 ConstantInt *DecorationId = 145 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(0)); 146 if (!DecorationId) 147 report_fatal_error("Expect SPIR-V <Decoration> operand to be the first " 148 "element of the decoration"); 149 auto MIB = MIRBuilder.buildInstr(SPIRV::OpDecorate) 150 .addUse(Reg) 151 .addImm(static_cast<uint32_t>(DecorationId->getZExtValue())); 152 for (unsigned OpI = 1, OpE = OpMD->getNumOperands(); OpI != OpE; ++OpI) { 153 if (ConstantInt *OpV = 154 mdconst::dyn_extract<ConstantInt>(OpMD->getOperand(OpI))) 155 MIB.addImm(static_cast<uint32_t>(OpV->getZExtValue())); 156 else if (MDString *OpV = dyn_cast<MDString>(OpMD->getOperand(OpI))) 157 addStringImm(OpV->getString(), MIB); 158 else 159 report_fatal_error("Unexpected operand of the decoration"); 160 } 161 } 162 } 163 164 // TODO: maybe the following two functions should be handled in the subtarget 165 // to allow for different OpenCL vs Vulkan handling. 166 unsigned storageClassToAddressSpace(SPIRV::StorageClass::StorageClass SC) { 167 switch (SC) { 168 case SPIRV::StorageClass::Function: 169 return 0; 170 case SPIRV::StorageClass::CrossWorkgroup: 171 return 1; 172 case SPIRV::StorageClass::UniformConstant: 173 return 2; 174 case SPIRV::StorageClass::Workgroup: 175 return 3; 176 case SPIRV::StorageClass::Generic: 177 return 4; 178 case SPIRV::StorageClass::DeviceOnlyINTEL: 179 return 5; 180 case SPIRV::StorageClass::HostOnlyINTEL: 181 return 6; 182 case SPIRV::StorageClass::Input: 183 return 7; 184 default: 185 report_fatal_error("Unable to get address space id"); 186 } 187 } 188 189 SPIRV::StorageClass::StorageClass 190 addressSpaceToStorageClass(unsigned AddrSpace, const SPIRVSubtarget &STI) { 191 switch (AddrSpace) { 192 case 0: 193 return SPIRV::StorageClass::Function; 194 case 1: 195 return SPIRV::StorageClass::CrossWorkgroup; 196 case 2: 197 return SPIRV::StorageClass::UniformConstant; 198 case 3: 199 return SPIRV::StorageClass::Workgroup; 200 case 4: 201 return SPIRV::StorageClass::Generic; 202 case 5: 203 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 204 ? SPIRV::StorageClass::DeviceOnlyINTEL 205 : SPIRV::StorageClass::CrossWorkgroup; 206 case 6: 207 return STI.canUseExtension(SPIRV::Extension::SPV_INTEL_usm_storage_classes) 208 ? SPIRV::StorageClass::HostOnlyINTEL 209 : SPIRV::StorageClass::CrossWorkgroup; 210 case 7: 211 return SPIRV::StorageClass::Input; 212 default: 213 report_fatal_error("Unknown address space"); 214 } 215 } 216 217 SPIRV::MemorySemantics::MemorySemantics 218 getMemSemanticsForStorageClass(SPIRV::StorageClass::StorageClass SC) { 219 switch (SC) { 220 case SPIRV::StorageClass::StorageBuffer: 221 case SPIRV::StorageClass::Uniform: 222 return SPIRV::MemorySemantics::UniformMemory; 223 case SPIRV::StorageClass::Workgroup: 224 return SPIRV::MemorySemantics::WorkgroupMemory; 225 case SPIRV::StorageClass::CrossWorkgroup: 226 return SPIRV::MemorySemantics::CrossWorkgroupMemory; 227 case SPIRV::StorageClass::AtomicCounter: 228 return SPIRV::MemorySemantics::AtomicCounterMemory; 229 case SPIRV::StorageClass::Image: 230 return SPIRV::MemorySemantics::ImageMemory; 231 default: 232 return SPIRV::MemorySemantics::None; 233 } 234 } 235 236 SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) { 237 switch (Ord) { 238 case AtomicOrdering::Acquire: 239 return SPIRV::MemorySemantics::Acquire; 240 case AtomicOrdering::Release: 241 return SPIRV::MemorySemantics::Release; 242 case AtomicOrdering::AcquireRelease: 243 return SPIRV::MemorySemantics::AcquireRelease; 244 case AtomicOrdering::SequentiallyConsistent: 245 return SPIRV::MemorySemantics::SequentiallyConsistent; 246 case AtomicOrdering::Unordered: 247 case AtomicOrdering::Monotonic: 248 case AtomicOrdering::NotAtomic: 249 return SPIRV::MemorySemantics::None; 250 } 251 llvm_unreachable(nullptr); 252 } 253 254 MachineInstr *getDefInstrMaybeConstant(Register &ConstReg, 255 const MachineRegisterInfo *MRI) { 256 MachineInstr *MI = MRI->getVRegDef(ConstReg); 257 MachineInstr *ConstInstr = 258 MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT 259 ? MRI->getVRegDef(MI->getOperand(1).getReg()) 260 : MI; 261 if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) { 262 if (GI->is(Intrinsic::spv_track_constant)) { 263 ConstReg = ConstInstr->getOperand(2).getReg(); 264 return MRI->getVRegDef(ConstReg); 265 } 266 } else if (ConstInstr->getOpcode() == SPIRV::ASSIGN_TYPE) { 267 ConstReg = ConstInstr->getOperand(1).getReg(); 268 return MRI->getVRegDef(ConstReg); 269 } 270 return MRI->getVRegDef(ConstReg); 271 } 272 273 uint64_t getIConstVal(Register ConstReg, const MachineRegisterInfo *MRI) { 274 const MachineInstr *MI = getDefInstrMaybeConstant(ConstReg, MRI); 275 assert(MI && MI->getOpcode() == TargetOpcode::G_CONSTANT); 276 return MI->getOperand(1).getCImm()->getValue().getZExtValue(); 277 } 278 279 bool isSpvIntrinsic(const MachineInstr &MI, Intrinsic::ID IntrinsicID) { 280 if (const auto *GI = dyn_cast<GIntrinsic>(&MI)) 281 return GI->is(IntrinsicID); 282 return false; 283 } 284 285 Type *getMDOperandAsType(const MDNode *N, unsigned I) { 286 Type *ElementTy = cast<ValueAsMetadata>(N->getOperand(I))->getType(); 287 return toTypedPointer(ElementTy); 288 } 289 290 // The set of names is borrowed from the SPIR-V translator. 291 // TODO: may be implemented in SPIRVBuiltins.td. 292 static bool isPipeOrAddressSpaceCastBI(const StringRef MangledName) { 293 return MangledName == "write_pipe_2" || MangledName == "read_pipe_2" || 294 MangledName == "write_pipe_2_bl" || MangledName == "read_pipe_2_bl" || 295 MangledName == "write_pipe_4" || MangledName == "read_pipe_4" || 296 MangledName == "reserve_write_pipe" || 297 MangledName == "reserve_read_pipe" || 298 MangledName == "commit_write_pipe" || 299 MangledName == "commit_read_pipe" || 300 MangledName == "work_group_reserve_write_pipe" || 301 MangledName == "work_group_reserve_read_pipe" || 302 MangledName == "work_group_commit_write_pipe" || 303 MangledName == "work_group_commit_read_pipe" || 304 MangledName == "get_pipe_num_packets_ro" || 305 MangledName == "get_pipe_max_packets_ro" || 306 MangledName == "get_pipe_num_packets_wo" || 307 MangledName == "get_pipe_max_packets_wo" || 308 MangledName == "sub_group_reserve_write_pipe" || 309 MangledName == "sub_group_reserve_read_pipe" || 310 MangledName == "sub_group_commit_write_pipe" || 311 MangledName == "sub_group_commit_read_pipe" || 312 MangledName == "to_global" || MangledName == "to_local" || 313 MangledName == "to_private"; 314 } 315 316 static bool isEnqueueKernelBI(const StringRef MangledName) { 317 return MangledName == "__enqueue_kernel_basic" || 318 MangledName == "__enqueue_kernel_basic_events" || 319 MangledName == "__enqueue_kernel_varargs" || 320 MangledName == "__enqueue_kernel_events_varargs"; 321 } 322 323 static bool isKernelQueryBI(const StringRef MangledName) { 324 return MangledName == "__get_kernel_work_group_size_impl" || 325 MangledName == "__get_kernel_sub_group_count_for_ndrange_impl" || 326 MangledName == "__get_kernel_max_sub_group_size_for_ndrange_impl" || 327 MangledName == "__get_kernel_preferred_work_group_size_multiple_impl"; 328 } 329 330 static bool isNonMangledOCLBuiltin(StringRef Name) { 331 if (!Name.starts_with("__")) 332 return false; 333 334 return isEnqueueKernelBI(Name) || isKernelQueryBI(Name) || 335 isPipeOrAddressSpaceCastBI(Name.drop_front(2)) || 336 Name == "__translate_sampler_initializer"; 337 } 338 339 std::string getOclOrSpirvBuiltinDemangledName(StringRef Name) { 340 bool IsNonMangledOCL = isNonMangledOCLBuiltin(Name); 341 bool IsNonMangledSPIRV = Name.starts_with("__spirv_"); 342 bool IsNonMangledHLSL = Name.starts_with("__hlsl_"); 343 bool IsMangled = Name.starts_with("_Z"); 344 345 // Otherwise use simple demangling to return the function name. 346 if (IsNonMangledOCL || IsNonMangledSPIRV || IsNonMangledHLSL || !IsMangled) 347 return Name.str(); 348 349 // Try to use the itanium demangler. 350 if (char *DemangledName = itaniumDemangle(Name.data())) { 351 std::string Result = DemangledName; 352 free(DemangledName); 353 return Result; 354 } 355 356 // Autocheck C++, maybe need to do explicit check of the source language. 357 // OpenCL C++ built-ins are declared in cl namespace. 358 // TODO: consider using 'St' abbriviation for cl namespace mangling. 359 // Similar to ::std:: in C++. 360 size_t Start, Len = 0; 361 size_t DemangledNameLenStart = 2; 362 if (Name.starts_with("_ZN")) { 363 // Skip CV and ref qualifiers. 364 size_t NameSpaceStart = Name.find_first_not_of("rVKRO", 3); 365 // All built-ins are in the ::cl:: namespace. 366 if (Name.substr(NameSpaceStart, 11) != "2cl7__spirv") 367 return std::string(); 368 DemangledNameLenStart = NameSpaceStart + 11; 369 } 370 Start = Name.find_first_not_of("0123456789", DemangledNameLenStart); 371 Name.substr(DemangledNameLenStart, Start - DemangledNameLenStart) 372 .getAsInteger(10, Len); 373 return Name.substr(Start, Len).str(); 374 } 375 376 bool hasBuiltinTypePrefix(StringRef Name) { 377 if (Name.starts_with("opencl.") || Name.starts_with("ocl_") || 378 Name.starts_with("spirv.")) 379 return true; 380 return false; 381 } 382 383 bool isSpecialOpaqueType(const Type *Ty) { 384 if (const TargetExtType *EType = dyn_cast<TargetExtType>(Ty)) 385 return hasBuiltinTypePrefix(EType->getName()); 386 387 return false; 388 } 389 390 bool isEntryPoint(const Function &F) { 391 // OpenCL handling: any function with the SPIR_KERNEL 392 // calling convention will be a potential entry point. 393 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) 394 return true; 395 396 // HLSL handling: special attribute are emitted from the 397 // front-end. 398 if (F.getFnAttribute("hlsl.shader").isValid()) 399 return true; 400 401 return false; 402 } 403 404 Type *parseBasicTypeName(StringRef &TypeName, LLVMContext &Ctx) { 405 TypeName.consume_front("atomic_"); 406 if (TypeName.consume_front("void")) 407 return Type::getVoidTy(Ctx); 408 else if (TypeName.consume_front("bool")) 409 return Type::getIntNTy(Ctx, 1); 410 else if (TypeName.consume_front("char") || 411 TypeName.consume_front("unsigned char") || 412 TypeName.consume_front("uchar")) 413 return Type::getInt8Ty(Ctx); 414 else if (TypeName.consume_front("short") || 415 TypeName.consume_front("unsigned short") || 416 TypeName.consume_front("ushort")) 417 return Type::getInt16Ty(Ctx); 418 else if (TypeName.consume_front("int") || 419 TypeName.consume_front("unsigned int") || 420 TypeName.consume_front("uint")) 421 return Type::getInt32Ty(Ctx); 422 else if (TypeName.consume_front("long") || 423 TypeName.consume_front("unsigned long") || 424 TypeName.consume_front("ulong")) 425 return Type::getInt64Ty(Ctx); 426 else if (TypeName.consume_front("half")) 427 return Type::getHalfTy(Ctx); 428 else if (TypeName.consume_front("float")) 429 return Type::getFloatTy(Ctx); 430 else if (TypeName.consume_front("double")) 431 return Type::getDoubleTy(Ctx); 432 433 // Unable to recognize SPIRV type name 434 return nullptr; 435 } 436 437 } // namespace llvm 438