1 //===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This implements Semantic Analysis for SPIRV constructs. 9 //===----------------------------------------------------------------------===// 10 11 #include "clang/Sema/SemaSPIRV.h" 12 #include "clang/Basic/TargetBuiltins.h" 13 #include "clang/Basic/TargetInfo.h" 14 #include "clang/Sema/Sema.h" 15 16 // SPIR-V enumerants. Enums have only the required entries, see SPIR-V specs for 17 // values. 18 // FIXME: either use the SPIRV-Headers or generate a custom header using the 19 // grammar (like done with MLIR). 20 namespace spirv { 21 enum class StorageClass : int { 22 Workgroup = 4, 23 CrossWorkgroup = 5, 24 Function = 7 25 }; 26 } 27 28 namespace clang { 29 30 SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} 31 32 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) { 33 assert(TheCall->getNumArgs() > 1); 34 QualType ArgTy0 = TheCall->getArg(0)->getType(); 35 36 for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) { 37 if (!S->getASTContext().hasSameUnqualifiedType( 38 ArgTy0, TheCall->getArg(I)->getType())) { 39 S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector) 40 << TheCall->getDirectCallee() << /*useAllTerminology*/ true 41 << SourceRange(TheCall->getArg(0)->getBeginLoc(), 42 TheCall->getArg(N - 1)->getEndLoc()); 43 return true; 44 } 45 } 46 return false; 47 } 48 49 static std::optional<int> 50 processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) { 51 ExprResult Arg = 52 SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(Argument)); 53 if (Arg.isInvalid()) 54 return true; 55 Call->setArg(Argument, Arg.get()); 56 57 const Expr *IntArg = Arg.get(); 58 SmallVector<PartialDiagnosticAt, 8> Notes; 59 Expr::EvalResult Eval; 60 Eval.Diag = &Notes; 61 if ((!IntArg->EvaluateAsConstantExpr(Eval, SemaRef.getASTContext())) || 62 !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) { 63 SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int) 64 << 0 << IntArg->getSourceRange(); 65 for (const PartialDiagnosticAt &PDiag : Notes) 66 SemaRef.Diag(PDiag.first, PDiag.second); 67 return true; 68 } 69 return {Eval.Val.getInt().getZExtValue()}; 70 } 71 72 static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) { 73 if (SemaRef.checkArgCount(Call, 2)) 74 return true; 75 76 { 77 ExprResult Arg = 78 SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(0)); 79 if (Arg.isInvalid()) 80 return true; 81 Call->setArg(0, Arg.get()); 82 83 QualType Ty = Arg.get()->getType(); 84 const auto *PtrTy = Ty->getAs<PointerType>(); 85 auto AddressSpaceNotInGeneric = [&](LangAS AS) { 86 if (SemaRef.LangOpts.OpenCL) 87 return AS != LangAS::opencl_generic; 88 return AS != LangAS::Default; 89 }; 90 if (!PtrTy || 91 AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) { 92 SemaRef.Diag(Arg.get()->getBeginLoc(), 93 diag::err_spirv_builtin_generic_cast_invalid_arg) 94 << Call->getSourceRange(); 95 return true; 96 } 97 } 98 99 spirv::StorageClass StorageClass; 100 if (std::optional<int> SCInt = 101 processConstant32BitIntArgument(SemaRef, Call, 1); 102 SCInt.has_value()) { 103 StorageClass = static_cast<spirv::StorageClass>(SCInt.value()); 104 if (StorageClass != spirv::StorageClass::CrossWorkgroup && 105 StorageClass != spirv::StorageClass::Workgroup && 106 StorageClass != spirv::StorageClass::Function) { 107 SemaRef.Diag(Call->getArg(1)->getBeginLoc(), 108 diag::err_spirv_enum_not_valid) 109 << 0 << Call->getArg(1)->getSourceRange(); 110 return true; 111 } 112 } else { 113 return true; 114 } 115 auto RT = Call->getArg(0)->getType(); 116 RT = RT->getPointeeType(); 117 auto Qual = RT.getQualifiers(); 118 LangAS AddrSpace; 119 switch (StorageClass) { 120 case spirv::StorageClass::CrossWorkgroup: 121 AddrSpace = 122 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global; 123 break; 124 case spirv::StorageClass::Workgroup: 125 AddrSpace = 126 SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local; 127 break; 128 case spirv::StorageClass::Function: 129 AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private 130 : LangAS::opencl_private; 131 break; 132 } 133 Qual.setAddressSpace(AddrSpace); 134 Call->setType(SemaRef.getASTContext().getPointerType( 135 SemaRef.getASTContext().getQualifiedType(RT.getUnqualifiedType(), Qual))); 136 137 return false; 138 } 139 140 bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI, 141 unsigned BuiltinID, 142 CallExpr *TheCall) { 143 if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin && 144 TI.getTriple().getArch() != llvm::Triple::spirv) { 145 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 0; 146 return true; 147 } 148 if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin && 149 TI.getTriple().getArch() != llvm::Triple::spirv32 && 150 TI.getTriple().getArch() != llvm::Triple::spirv64) { 151 SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 1; 152 return true; 153 } 154 155 switch (BuiltinID) { 156 case SPIRV::BI__builtin_spirv_distance: { 157 if (SemaRef.checkArgCount(TheCall, 2)) 158 return true; 159 160 ExprResult A = TheCall->getArg(0); 161 QualType ArgTyA = A.get()->getType(); 162 auto *VTyA = ArgTyA->getAs<VectorType>(); 163 if (VTyA == nullptr) { 164 SemaRef.Diag(A.get()->getBeginLoc(), 165 diag::err_typecheck_convert_incompatible) 166 << ArgTyA 167 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 168 << 0 << 0; 169 return true; 170 } 171 172 ExprResult B = TheCall->getArg(1); 173 QualType ArgTyB = B.get()->getType(); 174 auto *VTyB = ArgTyB->getAs<VectorType>(); 175 if (VTyB == nullptr) { 176 SemaRef.Diag(A.get()->getBeginLoc(), 177 diag::err_typecheck_convert_incompatible) 178 << ArgTyB 179 << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 180 << 0 << 0; 181 return true; 182 } 183 184 QualType RetTy = VTyA->getElementType(); 185 TheCall->setType(RetTy); 186 break; 187 } 188 case SPIRV::BI__builtin_spirv_length: { 189 if (SemaRef.checkArgCount(TheCall, 1)) 190 return true; 191 ExprResult A = TheCall->getArg(0); 192 QualType ArgTyA = A.get()->getType(); 193 auto *VTy = ArgTyA->getAs<VectorType>(); 194 if (VTy == nullptr) { 195 SemaRef.Diag(A.get()->getBeginLoc(), 196 diag::err_typecheck_convert_incompatible) 197 << ArgTyA 198 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 199 << 0 << 0; 200 return true; 201 } 202 QualType RetTy = VTy->getElementType(); 203 TheCall->setType(RetTy); 204 break; 205 } 206 case SPIRV::BI__builtin_spirv_reflect: { 207 if (SemaRef.checkArgCount(TheCall, 2)) 208 return true; 209 210 ExprResult A = TheCall->getArg(0); 211 QualType ArgTyA = A.get()->getType(); 212 auto *VTyA = ArgTyA->getAs<VectorType>(); 213 if (VTyA == nullptr) { 214 SemaRef.Diag(A.get()->getBeginLoc(), 215 diag::err_typecheck_convert_incompatible) 216 << ArgTyA 217 << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 218 << 0 << 0; 219 return true; 220 } 221 222 ExprResult B = TheCall->getArg(1); 223 QualType ArgTyB = B.get()->getType(); 224 auto *VTyB = ArgTyB->getAs<VectorType>(); 225 if (VTyB == nullptr) { 226 SemaRef.Diag(A.get()->getBeginLoc(), 227 diag::err_typecheck_convert_incompatible) 228 << ArgTyB 229 << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 230 << 0 << 0; 231 return true; 232 } 233 234 QualType RetTy = ArgTyA; 235 TheCall->setType(RetTy); 236 break; 237 } 238 case SPIRV::BI__builtin_spirv_smoothstep: { 239 if (SemaRef.checkArgCount(TheCall, 3)) 240 return true; 241 242 // Check if first argument has floating representation 243 ExprResult A = TheCall->getArg(0); 244 QualType ArgTyA = A.get()->getType(); 245 if (!ArgTyA->hasFloatingRepresentation()) { 246 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) 247 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0 248 << /* fp */ 1 << ArgTyA; 249 return true; 250 } 251 252 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 253 return true; 254 255 QualType RetTy = ArgTyA; 256 TheCall->setType(RetTy); 257 break; 258 } 259 case SPIRV::BI__builtin_spirv_faceforward: { 260 if (SemaRef.checkArgCount(TheCall, 3)) 261 return true; 262 263 // Check if first argument has floating representation 264 ExprResult A = TheCall->getArg(0); 265 QualType ArgTyA = A.get()->getType(); 266 if (!ArgTyA->hasFloatingRepresentation()) { 267 SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) 268 << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0 269 << /* fp */ 1 << ArgTyA; 270 return true; 271 } 272 273 if (CheckAllArgsHaveSameType(&SemaRef, TheCall)) 274 return true; 275 276 QualType RetTy = ArgTyA; 277 TheCall->setType(RetTy); 278 break; 279 } 280 case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: { 281 return checkGenericCastToPtr(SemaRef, TheCall); 282 } 283 } 284 return false; 285 } 286 } // namespace clang 287