//===- SPIR.cpp -----------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ABIInfoImpl.h"
#include "HLSLBufferLayoutBuilder.h"
#include "TargetInfo.h"

using namespace clang;
using namespace clang::CodeGen;

//===----------------------------------------------------------------------===//
// Base ABI and target codegen info implementation common between SPIR and
// SPIR-V.
//===----------------------------------------------------------------------===//

namespace {
class CommonSPIRABIInfo : public DefaultABIInfo {
public:
  CommonSPIRABIInfo(CodeGenTypes &CGT) : DefaultABIInfo(CGT) { setCCs(); }

private:
  void setCCs();
};

class SPIRVABIInfo : public CommonSPIRABIInfo {
public:
  SPIRVABIInfo(CodeGenTypes &CGT) : CommonSPIRABIInfo(CGT) {}
  void computeInfo(CGFunctionInfo &FI) const override;

private:
  ABIArgInfo classifyReturnType(QualType RetTy) const;
  ABIArgInfo classifyKernelArgumentType(QualType Ty) const;
  ABIArgInfo classifyArgumentType(QualType Ty) const;
};
} // end anonymous namespace
namespace {
class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
public:
  CommonSPIRTargetCodeGenInfo(CodeGen::CodeGenTypes &CGT)
      : TargetCodeGenInfo(std::make_unique<CommonSPIRABIInfo>(CGT)) {}
  CommonSPIRTargetCodeGenInfo(std::unique_ptr<ABIInfo> ABIInfo)
      : TargetCodeGenInfo(std::move(ABIInfo)) {}

  LangAS getASTAllocaAddressSpace() const override {
    return getLangASFromTargetAS(
        getABIInfo().getDataLayout().getAllocaAddrSpace());
  }

  unsigned getDeviceKernelCallingConv() const override;
  llvm::Type *getOpenCLType(CodeGenModule &CGM, const Type *T) const override;
  llvm::Type *
  getHLSLType(CodeGenModule &CGM, const Type *Ty,
              const SmallVector<int32_t> *Packoffsets = nullptr) const override;
  llvm::Type *getSPIRVImageTypeFromHLSLResource(
      const HLSLAttributedResourceType::Attributes &attributes,
      QualType SampledType, CodeGenModule &CGM) const;
  void
  setOCLKernelStubCallingConvention(const FunctionType *&FT) const override;
};

class SPIRVTargetCodeGenInfo : public CommonSPIRTargetCodeGenInfo {
public:
  SPIRVTargetCodeGenInfo(CodeGen::CodeGenTypes &CGT)
      : CommonSPIRTargetCodeGenInfo(std::make_unique<SPIRVABIInfo>(CGT)) {}
  void setCUDAKernelCallingConvention(const FunctionType *&FT) const override;
  LangAS getGlobalVarAddressSpace(CodeGenModule &CGM,
                                  const VarDecl *D) const override;
  void setTargetAttributes(const Decl *D, llvm::GlobalValue *GV,
                           CodeGen::CodeGenModule &M) const override;
  llvm::SyncScope::ID getLLVMSyncScopeID(const LangOptions &LangOpts,
                                         SyncScope Scope,
                                         llvm::AtomicOrdering Ordering,
                                         llvm::LLVMContext &Ctx) const override;
  bool supportsLibCall() const override {
    return getABIInfo().getTarget().getTriple().getVendor() !=
           llvm::Triple::AMD;
  }
};

// Maps a Clang sync scope to the name of the corresponding LLVM sync scope.
inline StringRef mapClangSyncScopeToLLVM(SyncScope Scope) {
  switch (Scope) {
  case SyncScope::HIPSingleThread:
  case SyncScope::SingleScope:
    return "singlethread";
  case SyncScope::HIPWavefront:
  case SyncScope::OpenCLSubGroup:
  case SyncScope::WavefrontScope:
    return "subgroup";
  case SyncScope::HIPWorkgroup:
  case SyncScope::OpenCLWorkGroup:
  case SyncScope::WorkgroupScope:
    return "workgroup";
  case SyncScope::HIPAgent:
  case SyncScope::OpenCLDevice:
  case SyncScope::DeviceScope:
    return "device";
  case SyncScope::SystemScope:
  case SyncScope::HIPSystem:
  case SyncScope::OpenCLAllSVMDevices:
    return "";
  }
  return "";
}
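
// Illustrative example (not exhaustive): an atomic operation at OpenCL
// work-group scope is expected to come out with the "workgroup" LLVM sync
// scope, i.e. something along the lines of:
//   %old = atomicrmw add ptr %p, i32 1 syncscope("workgroup") acq_rel
// The empty string selects LLVM's default (system-wide) sync scope.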
} // End anonymous namespace.

void CommonSPIRABIInfo::setCCs() {
  assert(getRuntimeCC() == llvm::CallingConv::C);
  RuntimeCC = llvm::CallingConv::SPIR_FUNC;
}

ABIArgInfo SPIRVABIInfo::classifyReturnType(QualType RetTy) const {
  if (getTarget().getTriple().getVendor() != llvm::Triple::AMD)
    return DefaultABIInfo::classifyReturnType(RetTy);
  if (!isAggregateTypeForABI(RetTy) || getRecordArgABI(RetTy, getCXXABI()))
    return DefaultABIInfo::classifyReturnType(RetTy);

  if (const RecordType *RT = RetTy->getAs<RecordType>()) {
    const RecordDecl *RD = RT->getDecl();
    if (RD->hasFlexibleArrayMember())
      return DefaultABIInfo::classifyReturnType(RetTy);
  }

  // TODO: The AMDGPU ABI is non-trivial to represent in SPIR-V; to avoid
  // encoding various architecture-specific bits here, we return everything
  // as direct, retaining type information for things like aggregates so it
  // is available later when translating back to LLVM / lowering in the
  // backend. For the same reason we disable flattening, as the outcomes can
  // mismatch between SPIR-V and AMDGPU. This will be revisited / optimised
  // in the future.
  return ABIArgInfo::getDirect(CGT.ConvertType(RetTy), 0u, nullptr, false);
}
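
// For illustration, under the AMD vendor triple a small aggregate such as
//   struct Pair { int A; int B; };
// is expected to be returned direct as the converted LLVM type (e.g.
// %struct.Pair) rather than indirectly via an sret pointer, and without
// being flattened into individual scalars.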

ABIArgInfo SPIRVABIInfo::classifyKernelArgumentType(QualType Ty) const {
  if (getContext().getLangOpts().CUDAIsDevice) {
    // Coerce pointer arguments with default address space to CrossWorkGroup
    // pointers for HIPSPV/CUDASPV. When the language mode is HIP/CUDA, the
    // SPIRTargetInfo maps cuda_device to SPIR-V's CrossWorkGroup address
    // space.
    llvm::Type *LTy = CGT.ConvertType(Ty);
    auto DefaultAS = getContext().getTargetAddressSpace(LangAS::Default);
    auto GlobalAS = getContext().getTargetAddressSpace(LangAS::cuda_device);
    auto *PtrTy = llvm::dyn_cast<llvm::PointerType>(LTy);
    if (PtrTy && PtrTy->getAddressSpace() == DefaultAS) {
      LTy = llvm::PointerType::get(PtrTy->getContext(), GlobalAS);
      return ABIArgInfo::getDirect(LTy, 0, nullptr, false);
    }

    if (isAggregateTypeForABI(Ty)) {
      if (getTarget().getTriple().getVendor() == llvm::Triple::AMD)
        // TODO: The AMDGPU kernel ABI passes aggregates byref, which is not
        // currently expressible in SPIR-V; SPIR-V passes aggregates byval,
        // which the AMDGPU kernel ABI does not allow. Passing aggregates as
        // direct works around this impedance mismatch, as it retains type
        // info and can be correctly handled, post reverse-translation, by
        // the AMDGPU BE, which has to support this CC for legacy OpenCL
        // purposes. It can be brittle and does lead to performance
        // degradation in certain pathological cases. This will be revisited
        // / optimised in the future, once a way to deal with the byref/byval
        // impedance mismatch is identified.
        return ABIArgInfo::getDirect(LTy, 0, nullptr, false);
      // Force copying aggregate type in kernel arguments by value when
      // compiling CUDA targeting SPIR-V. This is required for the object
      // copied to be valid on the device.
      // This behavior follows the CUDA spec
      // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#global-function-argument-processing,
      // and matches the NVPTX implementation. TODO: hardcoding to 0 should
      // be revisited if HIPSPV / byval starts making use of the AS of an
      // indirect arg.
      return getNaturalAlignIndirect(Ty, /*AddrSpace=*/0, /*byval=*/true);
    }
  }
  return classifyArgumentType(Ty);
}
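
// As an illustration of the coercions above, a HIP kernel along the lines of
//   __global__ void K(int *P, SomeAggregate A);
// is expected to have P's generic pointer coerced to the CrossWorkGroup
// address space (e.g. ptr addrspace(1)), while A is passed direct on AMD
// SPIR-V and byval-indirect (address space 0) otherwise.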

ABIArgInfo SPIRVABIInfo::classifyArgumentType(QualType Ty) const {
  if (getTarget().getTriple().getVendor() != llvm::Triple::AMD)
    return DefaultABIInfo::classifyArgumentType(Ty);
  if (!isAggregateTypeForABI(Ty))
    return DefaultABIInfo::classifyArgumentType(Ty);

  // Records with non-trivial destructors/copy-constructors should not be
  // passed by value.
  if (auto RAA = getRecordArgABI(Ty, getCXXABI()))
    return getNaturalAlignIndirect(Ty, getDataLayout().getAllocaAddrSpace(),
                                   RAA == CGCXXABI::RAA_DirectInMemory);

  if (const RecordType *RT = Ty->getAs<RecordType>()) {
    const RecordDecl *RD = RT->getDecl();
    if (RD->hasFlexibleArrayMember())
      return DefaultABIInfo::classifyArgumentType(Ty);
  }

  return ABIArgInfo::getDirect(CGT.ConvertType(Ty), 0u, nullptr, false);
}

void SPIRVABIInfo::computeInfo(CGFunctionInfo &FI) const {
  // The logic is the same as in DefaultABIInfo, except for the handling of
  // kernel arguments.
  llvm::CallingConv::ID CC = FI.getCallingConvention();

  if (!getCXXABI().classifyReturnType(FI))
    FI.getReturnInfo() = classifyReturnType(FI.getReturnType());

  for (auto &I : FI.arguments()) {
    if (CC == llvm::CallingConv::SPIR_KERNEL) {
      I.info = classifyKernelArgumentType(I.type);
    } else {
      I.info = classifyArgumentType(I.type);
    }
  }
}

namespace clang {
namespace CodeGen {
void computeSPIRKernelABIInfo(CodeGenModule &CGM, CGFunctionInfo &FI) {
  if (CGM.getTarget().getTriple().isSPIRV())
    SPIRVABIInfo(CGM.getTypes()).computeInfo(FI);
  else
    CommonSPIRABIInfo(CGM.getTypes()).computeInfo(FI);
}
} // namespace CodeGen
} // namespace clang

unsigned CommonSPIRTargetCodeGenInfo::getDeviceKernelCallingConv() const {
  return llvm::CallingConv::SPIR_KERNEL;
}

void SPIRVTargetCodeGenInfo::setCUDAKernelCallingConvention(
    const FunctionType *&FT) const {
  // Convert HIP kernels to SPIR-V kernels.
  if (getABIInfo().getContext().getLangOpts().HIP) {
    FT = getABIInfo().getContext().adjustFunctionType(
        FT, FT->getExtInfo().withCallingConv(CC_DeviceKernel));
    return;
  }
}
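
// Net effect (illustrative): together with getDeviceKernelCallingConv above,
// a HIP __global__ function targeting SPIR-V should end up emitted as, e.g.:
//   define spir_kernel void @K(...)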
239 
setOCLKernelStubCallingConvention(const FunctionType * & FT) const240 void CommonSPIRTargetCodeGenInfo::setOCLKernelStubCallingConvention(
241     const FunctionType *&FT) const {
242   FT = getABIInfo().getContext().adjustFunctionType(
243       FT, FT->getExtInfo().withCallingConv(CC_SpirFunction));
244 }

LangAS
SPIRVTargetCodeGenInfo::getGlobalVarAddressSpace(CodeGenModule &CGM,
                                                 const VarDecl *D) const {
  assert(!CGM.getLangOpts().OpenCL &&
         !(CGM.getLangOpts().CUDA && CGM.getLangOpts().CUDAIsDevice) &&
         "Address space agnostic languages only");
  // If we're here it means we're using the SPIRDefIsGen address-space map,
  // so for the global AS we can rely on either cuda_device or sycl_global
  // being correct; however, since this is not a CUDA device context, we use
  // sycl_global to avoid confusion with the assertion above.
  LangAS DefaultGlobalAS = getLangASFromTargetAS(
      CGM.getContext().getTargetAddressSpace(LangAS::sycl_global));
  if (!D)
    return DefaultGlobalAS;

  LangAS AddrSpace = D->getType().getAddressSpace();
  if (AddrSpace != LangAS::Default)
    return AddrSpace;

  return DefaultGlobalAS;
}
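
// E.g. a namespace-scope variable declared without an explicit address-space
// qualifier is expected to land in the global (CrossWorkGroup) address space,
// typically addrspace(1) in the emitted module.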

void SPIRVTargetCodeGenInfo::setTargetAttributes(
    const Decl *D, llvm::GlobalValue *GV, CodeGen::CodeGenModule &M) const {
  if (!M.getLangOpts().HIP ||
      M.getTarget().getTriple().getVendor() != llvm::Triple::AMD)
    return;
  if (GV->isDeclaration())
    return;

  auto F = dyn_cast<llvm::Function>(GV);
  if (!F)
    return;

  auto FD = dyn_cast_or_null<FunctionDecl>(D);
  if (!FD)
    return;
  if (!FD->hasAttr<CUDAGlobalAttr>())
    return;

  unsigned N = M.getLangOpts().GPUMaxThreadsPerBlock;
  if (auto FlatWGS = FD->getAttr<AMDGPUFlatWorkGroupSizeAttr>())
    N = FlatWGS->getMax()->EvaluateKnownConstInt(M.getContext()).getExtValue();

  // We encode the maximum flat WG size in the first component of the 3D
  // max_work_group_size attribute, which will get reverse-translated into
  // the original AMDGPU attribute when targeting AMDGPU.
  auto Int32Ty = llvm::IntegerType::getInt32Ty(M.getLLVMContext());
  llvm::Metadata *AttrMDArgs[] = {
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int32Ty, N)),
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int32Ty, 1)),
      llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(Int32Ty, 1))};

  F->setMetadata("max_work_group_size",
                 llvm::MDNode::get(M.getLLVMContext(), AttrMDArgs));
}
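
// For a kernel capped at, say, 1024 threads per block, the resulting IR
// should look roughly like:
//   define spir_kernel void @K(...) !max_work_group_size !0 { ... }
//   !0 = !{i32 1024, i32 1, i32 1}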

llvm::SyncScope::ID
SPIRVTargetCodeGenInfo::getLLVMSyncScopeID(const LangOptions &, SyncScope Scope,
                                           llvm::AtomicOrdering,
                                           llvm::LLVMContext &Ctx) const {
  return Ctx.getOrInsertSyncScopeID(mapClangSyncScopeToLLVM(Scope));
}

/// Construct a SPIR-V target extension type for the given OpenCL image type.
static llvm::Type *getSPIRVImageType(llvm::LLVMContext &Ctx, StringRef BaseType,
                                     StringRef OpenCLName,
                                     unsigned AccessQualifier) {
  // These parameters correspond to the operands of OpTypeImage (see
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeImage
  // for more details). The first six integer parameters all default to 0 and
  // are only changed for the image types that require a different value. The
  // seventh integer parameter is the access qualifier, which is tacked on at
  // the end.
  SmallVector<unsigned, 7> IntParams = {0, 0, 0, 0, 0, 0};

  // Choose the dimension of the image--this corresponds to the Dim enum in
  // SPIR-V (first integer parameter of OpTypeImage).
  if (OpenCLName.starts_with("image2d"))
    IntParams[0] = 1;
  else if (OpenCLName.starts_with("image3d"))
    IntParams[0] = 2;
  else if (OpenCLName == "image1d_buffer")
    IntParams[0] = 5; // Buffer
  else
    assert(OpenCLName.starts_with("image1d") && "Unknown image type");

  // Set the other integer parameters of OpTypeImage if necessary. Note that
  // the OpenCL image types don't provide any information for the Sampled or
  // Image Format parameters.
  if (OpenCLName.contains("_depth"))
    IntParams[1] = 1;
  if (OpenCLName.contains("_array"))
    IntParams[2] = 1;
  if (OpenCLName.contains("_msaa"))
    IntParams[3] = 1;

  // Access qualifier
  IntParams.push_back(AccessQualifier);

  return llvm::TargetExtType::get(Ctx, BaseType, {llvm::Type::getVoidTy(Ctx)},
                                  IntParams);
}
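
// For example, the OpenCL type "image2d_depth_ro_t" should map to roughly:
//   target("spirv.Image", void, 1, 1, 0, 0, 0, 0, 0)
// i.e. Dim=2D(1), Depth=1, Arrayed=0, MS=0, Sampled=0, Format=Unknown(0),
// Access=read_only(0).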

llvm::Type *CommonSPIRTargetCodeGenInfo::getOpenCLType(CodeGenModule &CGM,
                                                       const Type *Ty) const {
  llvm::LLVMContext &Ctx = CGM.getLLVMContext();
  if (auto *PipeTy = dyn_cast<PipeType>(Ty))
    return llvm::TargetExtType::get(Ctx, "spirv.Pipe", {},
                                    {!PipeTy->isReadOnly()});
  if (auto *BuiltinTy = dyn_cast<BuiltinType>(Ty)) {
    enum AccessQualifier : unsigned { AQ_ro = 0, AQ_wo = 1, AQ_rw = 2 };
    switch (BuiltinTy->getKind()) {
#define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix)                   \
    case BuiltinType::Id:                                                      \
      return getSPIRVImageType(Ctx, "spirv.Image", #ImgType, AQ_##Suffix);
#include "clang/Basic/OpenCLImageTypes.def"
    case BuiltinType::OCLSampler:
      return llvm::TargetExtType::get(Ctx, "spirv.Sampler");
    case BuiltinType::OCLEvent:
      return llvm::TargetExtType::get(Ctx, "spirv.Event");
    case BuiltinType::OCLClkEvent:
      return llvm::TargetExtType::get(Ctx, "spirv.DeviceEvent");
    case BuiltinType::OCLQueue:
      return llvm::TargetExtType::get(Ctx, "spirv.Queue");
    case BuiltinType::OCLReserveID:
      return llvm::TargetExtType::get(Ctx, "spirv.ReserveId");
#define INTEL_SUBGROUP_AVC_TYPE(Name, Id)                                      \
    case BuiltinType::OCLIntelSubgroupAVC##Id:                                 \
      return llvm::TargetExtType::get(Ctx, "spirv.Avc" #Id "INTEL");
#include "clang/Basic/OpenCLExtensionTypes.def"
    default:
      return nullptr;
    }
  }

  return nullptr;
}

// Gets a spirv.IntegralConstant or spirv.Literal. If IntegralType is present,
// returns an IntegralConstant; otherwise returns a Literal.
static llvm::Type *getInlineSpirvConstant(CodeGenModule &CGM,
                                          llvm::Type *IntegralType,
                                          llvm::APInt Value) {
  llvm::LLVMContext &Ctx = CGM.getLLVMContext();

  // Convert the APInt value to an array of uint32_t words, least-significant
  // word first.
  llvm::SmallVector<uint32_t> Words;

  while (Value.ugt(0)) {
    uint32_t Word = Value.trunc(32).getZExtValue();
    Value.lshrInPlace(32);

    Words.push_back(Word);
  }
  if (Words.empty())
    Words.push_back(0);

  if (IntegralType)
    return llvm::TargetExtType::get(Ctx, "spirv.IntegralConstant",
                                    {IntegralType}, Words);
  return llvm::TargetExtType::get(Ctx, "spirv.Literal", {}, Words);
}
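
// E.g. a 64-bit constant 0x0000000100000002 is split into 32-bit words,
// least-significant first, giving {2, 1}, so the result would be roughly:
//   target("spirv.IntegralConstant", i64, 2, 1)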

static llvm::Type *getInlineSpirvType(CodeGenModule &CGM,
                                      const HLSLInlineSpirvType *SpirvType) {
  llvm::LLVMContext &Ctx = CGM.getLLVMContext();

  llvm::SmallVector<llvm::Type *> Operands;

  for (auto &Operand : SpirvType->getOperands()) {
    using SpirvOperandKind = SpirvOperand::SpirvOperandKind;

    llvm::Type *Result = nullptr;
    switch (Operand.getKind()) {
    case SpirvOperandKind::ConstantId: {
      llvm::Type *IntegralType =
          CGM.getTypes().ConvertType(Operand.getResultType());

      Result = getInlineSpirvConstant(CGM, IntegralType, Operand.getValue());
      break;
    }
    case SpirvOperandKind::Literal: {
      Result = getInlineSpirvConstant(CGM, nullptr, Operand.getValue());
      break;
    }
    case SpirvOperandKind::TypeId: {
      QualType TypeOperand = Operand.getResultType();
      if (auto *RT = TypeOperand->getAs<RecordType>()) {
        auto *RD = RT->getDecl();
        assert(RD->isCompleteDefinition() &&
               "Type completion should have been required in Sema");

        const FieldDecl *HandleField = RD->findFirstNamedDataMember();
        if (HandleField) {
          QualType ResourceType = HandleField->getType();
          if (ResourceType->getAs<HLSLAttributedResourceType>()) {
            TypeOperand = ResourceType;
          }
        }
      }
      Result = CGM.getTypes().ConvertType(TypeOperand);
      break;
    }
    default:
      llvm_unreachable("HLSLInlineSpirvType had invalid operand!");
      break;
    }

    assert(Result);
    Operands.push_back(Result);
  }

  return llvm::TargetExtType::get(Ctx, "spirv.Type", Operands,
                                  {SpirvType->getOpcode(), SpirvType->getSize(),
                                   SpirvType->getAlignment()});
}

llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType(
    CodeGenModule &CGM, const Type *Ty,
    const SmallVector<int32_t> *Packoffsets) const {
  llvm::LLVMContext &Ctx = CGM.getLLVMContext();

  if (auto *SpirvType = dyn_cast<HLSLInlineSpirvType>(Ty))
    return getInlineSpirvType(CGM, SpirvType);

  auto *ResType = dyn_cast<HLSLAttributedResourceType>(Ty);
  if (!ResType)
    return nullptr;

  const HLSLAttributedResourceType::Attributes &ResAttrs = ResType->getAttrs();
  switch (ResAttrs.ResourceClass) {
  case llvm::dxil::ResourceClass::UAV:
  case llvm::dxil::ResourceClass::SRV: {
    // TypedBuffer and RawBuffer both need an element type.
    QualType ContainedTy = ResType->getContainedType();
    if (ContainedTy.isNull())
      return nullptr;

    assert(!ResAttrs.IsROV &&
           "Rasterizer order views not implemented for SPIR-V yet");

    if (!ResAttrs.RawBuffer) {
      // Convert the element type.
      return getSPIRVImageTypeFromHLSLResource(ResAttrs, ContainedTy, CGM);
    }

    llvm::Type *ElemType = CGM.getTypes().ConvertTypeForMem(ContainedTy);
    llvm::ArrayType *RuntimeArrayType = llvm::ArrayType::get(ElemType, 0);
    uint32_t StorageClass = /* StorageBuffer storage class */ 12;
    bool IsWritable = ResAttrs.ResourceClass == llvm::dxil::ResourceClass::UAV;
    return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer",
                                    {RuntimeArrayType},
                                    {StorageClass, IsWritable});
  }
  case llvm::dxil::ResourceClass::CBuffer: {
    QualType ContainedTy = ResType->getContainedType();
    if (ContainedTy.isNull() || !ContainedTy->isStructureType())
      return nullptr;

    llvm::Type *BufferLayoutTy =
        HLSLBufferLayoutBuilder(CGM, "spirv.Layout")
            .createLayoutType(ContainedTy->getAsStructureType(), Packoffsets);
    uint32_t StorageClass = /* Uniform storage class */ 2;
    return llvm::TargetExtType::get(Ctx, "spirv.VulkanBuffer", {BufferLayoutTy},
                                    {StorageClass, false});
  }
  case llvm::dxil::ResourceClass::Sampler:
    return llvm::TargetExtType::get(Ctx, "spirv.Sampler");
  }
  return nullptr;
}
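
// For example, an HLSL RWStructuredBuffer<float> (a raw-buffer UAV) should
// lower to roughly:
//   target("spirv.VulkanBuffer", [0 x float], 12, 1)
// i.e. a runtime array in the StorageBuffer storage class (12), writable (1);
// a cbuffer instead wraps its spirv.Layout type with Uniform (2), read-only.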

llvm::Type *CommonSPIRTargetCodeGenInfo::getSPIRVImageTypeFromHLSLResource(
    const HLSLAttributedResourceType::Attributes &attributes, QualType Ty,
    CodeGenModule &CGM) const {
  llvm::LLVMContext &Ctx = CGM.getLLVMContext();

  Ty = Ty->getCanonicalTypeUnqualified();
  if (const VectorType *V = dyn_cast<VectorType>(Ty))
    Ty = V->getElementType();
  assert(!Ty->isVectorType() && "We still have a vector type.");

  llvm::Type *SampledType = CGM.getTypes().ConvertTypeForMem(Ty);

  assert((SampledType->isIntegerTy() || SampledType->isFloatingPointTy()) &&
         "The element type for a SPIR-V resource must be a scalar integer or "
         "floating point type.");

  // These parameters correspond to the operands of the OpTypeImage SPIR-V
  // instruction. See
  // https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpTypeImage.
  SmallVector<unsigned, 6> IntParams(6, 0);

  const char *Name =
      Ty->isSignedIntegerType() ? "spirv.SignedImage" : "spirv.Image";

  // Dim
  // For now we assume everything is a buffer.
  IntParams[0] = 5;

  // Depth
  // HLSL does not indicate whether this is a depth texture, so use "unknown".
  IntParams[1] = 2;

  // Arrayed
  IntParams[2] = 0;

  // MS
  IntParams[3] = 0;

  // Sampled
  IntParams[4] =
      attributes.ResourceClass == llvm::dxil::ResourceClass::UAV ? 2 : 1;

  // Image format.
  // Setting to unknown for now.
  IntParams[5] = 0;

  llvm::TargetExtType *ImageType =
      llvm::TargetExtType::get(Ctx, Name, {SampledType}, IntParams);
  return ImageType;
}
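
// For instance, an HLSL RWBuffer<float4> (a typed-buffer UAV) should come out
// roughly as:
//   target("spirv.Image", float, 5, 2, 0, 0, 2, 0)
// i.e. sampled type float, Dim=Buffer(5), Depth=unknown(2), not arrayed, not
// multisampled, Sampled=2 (read/write, no sampler), Format=Unknown(0).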

std::unique_ptr<TargetCodeGenInfo>
CodeGen::createCommonSPIRTargetCodeGenInfo(CodeGenModule &CGM) {
  return std::make_unique<CommonSPIRTargetCodeGenInfo>(CGM.getTypes());
}

std::unique_ptr<TargetCodeGenInfo>
CodeGen::createSPIRVTargetCodeGenInfo(CodeGenModule &CGM) {
  return std::make_unique<SPIRVTargetCodeGenInfo>(CGM.getTypes());
}