xref: /freebsd/contrib/llvm-project/clang/lib/Sema/SemaSPIRV.cpp (revision 9c77fb6aaa366cbabc80ee1b834bcfe4df135491)
1 //===- SemaSPIRV.cpp - Semantic Analysis for SPIRV constructs--------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // This implements Semantic Analysis for SPIRV constructs.
9 //===----------------------------------------------------------------------===//
10 
11 #include "clang/Sema/SemaSPIRV.h"
12 #include "clang/Basic/TargetBuiltins.h"
13 #include "clang/Basic/TargetInfo.h"
14 #include "clang/Sema/Sema.h"
15 
// A minimal subset of SPIR-V enumerants: only the entries this file needs are
// listed, and their numeric values must match the SPIR-V specification.
// FIXME: either use the SPIRV-Headers or generate a custom header using the
// grammar (like done with MLIR).
namespace spirv {
/// SPIR-V StorageClass operands accepted by the generic-cast builtin.
enum class StorageClass : int {
  Workgroup = 4,      // SPIR-V "Workgroup" storage class.
  CrossWorkgroup = 5, // SPIR-V "CrossWorkgroup" storage class.
  Function = 7,       // SPIR-V "Function" storage class.
};
} // namespace spirv
27 
28 namespace clang {
29 
30 SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
31 
32 static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
33   assert(TheCall->getNumArgs() > 1);
34   QualType ArgTy0 = TheCall->getArg(0)->getType();
35 
36   for (unsigned I = 1, N = TheCall->getNumArgs(); I < N; ++I) {
37     if (!S->getASTContext().hasSameUnqualifiedType(
38             ArgTy0, TheCall->getArg(I)->getType())) {
39       S->Diag(TheCall->getBeginLoc(), diag::err_vec_builtin_incompatible_vector)
40           << TheCall->getDirectCallee() << /*useAllTerminology*/ true
41           << SourceRange(TheCall->getArg(0)->getBeginLoc(),
42                          TheCall->getArg(N - 1)->getEndLoc());
43       return true;
44     }
45   }
46   return false;
47 }
48 
49 static std::optional<int>
50 processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
51   ExprResult Arg =
52       SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(Argument));
53   if (Arg.isInvalid())
54     return true;
55   Call->setArg(Argument, Arg.get());
56 
57   const Expr *IntArg = Arg.get();
58   SmallVector<PartialDiagnosticAt, 8> Notes;
59   Expr::EvalResult Eval;
60   Eval.Diag = &Notes;
61   if ((!IntArg->EvaluateAsConstantExpr(Eval, SemaRef.getASTContext())) ||
62       !Eval.Val.isInt() || Eval.Val.getInt().getBitWidth() > 32) {
63     SemaRef.Diag(IntArg->getBeginLoc(), diag::err_spirv_enum_not_int)
64         << 0 << IntArg->getSourceRange();
65     for (const PartialDiagnosticAt &PDiag : Notes)
66       SemaRef.Diag(PDiag.first, PDiag.second);
67     return true;
68   }
69   return {Eval.Val.getInt().getZExtValue()};
70 }
71 
72 static bool checkGenericCastToPtr(Sema &SemaRef, CallExpr *Call) {
73   if (SemaRef.checkArgCount(Call, 2))
74     return true;
75 
76   {
77     ExprResult Arg =
78         SemaRef.DefaultFunctionArrayLvalueConversion(Call->getArg(0));
79     if (Arg.isInvalid())
80       return true;
81     Call->setArg(0, Arg.get());
82 
83     QualType Ty = Arg.get()->getType();
84     const auto *PtrTy = Ty->getAs<PointerType>();
85     auto AddressSpaceNotInGeneric = [&](LangAS AS) {
86       if (SemaRef.LangOpts.OpenCL)
87         return AS != LangAS::opencl_generic;
88       return AS != LangAS::Default;
89     };
90     if (!PtrTy ||
91         AddressSpaceNotInGeneric(PtrTy->getPointeeType().getAddressSpace())) {
92       SemaRef.Diag(Arg.get()->getBeginLoc(),
93                    diag::err_spirv_builtin_generic_cast_invalid_arg)
94           << Call->getSourceRange();
95       return true;
96     }
97   }
98 
99   spirv::StorageClass StorageClass;
100   if (std::optional<int> SCInt =
101           processConstant32BitIntArgument(SemaRef, Call, 1);
102       SCInt.has_value()) {
103     StorageClass = static_cast<spirv::StorageClass>(SCInt.value());
104     if (StorageClass != spirv::StorageClass::CrossWorkgroup &&
105         StorageClass != spirv::StorageClass::Workgroup &&
106         StorageClass != spirv::StorageClass::Function) {
107       SemaRef.Diag(Call->getArg(1)->getBeginLoc(),
108                    diag::err_spirv_enum_not_valid)
109           << 0 << Call->getArg(1)->getSourceRange();
110       return true;
111     }
112   } else {
113     return true;
114   }
115   auto RT = Call->getArg(0)->getType();
116   RT = RT->getPointeeType();
117   auto Qual = RT.getQualifiers();
118   LangAS AddrSpace;
119   switch (StorageClass) {
120   case spirv::StorageClass::CrossWorkgroup:
121     AddrSpace =
122         SemaRef.LangOpts.isSYCL() ? LangAS::sycl_global : LangAS::opencl_global;
123     break;
124   case spirv::StorageClass::Workgroup:
125     AddrSpace =
126         SemaRef.LangOpts.isSYCL() ? LangAS::sycl_local : LangAS::opencl_local;
127     break;
128   case spirv::StorageClass::Function:
129     AddrSpace = SemaRef.LangOpts.isSYCL() ? LangAS::sycl_private
130                                           : LangAS::opencl_private;
131     break;
132   }
133   Qual.setAddressSpace(AddrSpace);
134   Call->setType(SemaRef.getASTContext().getPointerType(
135       SemaRef.getASTContext().getQualifiedType(RT.getUnqualifiedType(), Qual)));
136 
137   return false;
138 }
139 
140 bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
141                                               unsigned BuiltinID,
142                                               CallExpr *TheCall) {
143   if (BuiltinID >= SPIRV::FirstVKBuiltin && BuiltinID <= SPIRV::LastVKBuiltin &&
144       TI.getTriple().getArch() != llvm::Triple::spirv) {
145     SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 0;
146     return true;
147   }
148   if (BuiltinID >= SPIRV::FirstCLBuiltin && BuiltinID <= SPIRV::LastTSBuiltin &&
149       TI.getTriple().getArch() != llvm::Triple::spirv32 &&
150       TI.getTriple().getArch() != llvm::Triple::spirv64) {
151     SemaRef.Diag(TheCall->getBeginLoc(), diag::err_spirv_invalid_target) << 1;
152     return true;
153   }
154 
155   switch (BuiltinID) {
156   case SPIRV::BI__builtin_spirv_distance: {
157     if (SemaRef.checkArgCount(TheCall, 2))
158       return true;
159 
160     ExprResult A = TheCall->getArg(0);
161     QualType ArgTyA = A.get()->getType();
162     auto *VTyA = ArgTyA->getAs<VectorType>();
163     if (VTyA == nullptr) {
164       SemaRef.Diag(A.get()->getBeginLoc(),
165                    diag::err_typecheck_convert_incompatible)
166           << ArgTyA
167           << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
168           << 0 << 0;
169       return true;
170     }
171 
172     ExprResult B = TheCall->getArg(1);
173     QualType ArgTyB = B.get()->getType();
174     auto *VTyB = ArgTyB->getAs<VectorType>();
175     if (VTyB == nullptr) {
176       SemaRef.Diag(A.get()->getBeginLoc(),
177                    diag::err_typecheck_convert_incompatible)
178           << ArgTyB
179           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
180           << 0 << 0;
181       return true;
182     }
183 
184     QualType RetTy = VTyA->getElementType();
185     TheCall->setType(RetTy);
186     break;
187   }
188   case SPIRV::BI__builtin_spirv_length: {
189     if (SemaRef.checkArgCount(TheCall, 1))
190       return true;
191     ExprResult A = TheCall->getArg(0);
192     QualType ArgTyA = A.get()->getType();
193     auto *VTy = ArgTyA->getAs<VectorType>();
194     if (VTy == nullptr) {
195       SemaRef.Diag(A.get()->getBeginLoc(),
196                    diag::err_typecheck_convert_incompatible)
197           << ArgTyA
198           << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
199           << 0 << 0;
200       return true;
201     }
202     QualType RetTy = VTy->getElementType();
203     TheCall->setType(RetTy);
204     break;
205   }
206   case SPIRV::BI__builtin_spirv_reflect: {
207     if (SemaRef.checkArgCount(TheCall, 2))
208       return true;
209 
210     ExprResult A = TheCall->getArg(0);
211     QualType ArgTyA = A.get()->getType();
212     auto *VTyA = ArgTyA->getAs<VectorType>();
213     if (VTyA == nullptr) {
214       SemaRef.Diag(A.get()->getBeginLoc(),
215                    diag::err_typecheck_convert_incompatible)
216           << ArgTyA
217           << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
218           << 0 << 0;
219       return true;
220     }
221 
222     ExprResult B = TheCall->getArg(1);
223     QualType ArgTyB = B.get()->getType();
224     auto *VTyB = ArgTyB->getAs<VectorType>();
225     if (VTyB == nullptr) {
226       SemaRef.Diag(A.get()->getBeginLoc(),
227                    diag::err_typecheck_convert_incompatible)
228           << ArgTyB
229           << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
230           << 0 << 0;
231       return true;
232     }
233 
234     QualType RetTy = ArgTyA;
235     TheCall->setType(RetTy);
236     break;
237   }
238   case SPIRV::BI__builtin_spirv_smoothstep: {
239     if (SemaRef.checkArgCount(TheCall, 3))
240       return true;
241 
242     // Check if first argument has floating representation
243     ExprResult A = TheCall->getArg(0);
244     QualType ArgTyA = A.get()->getType();
245     if (!ArgTyA->hasFloatingRepresentation()) {
246       SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
247           << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
248           << /* fp */ 1 << ArgTyA;
249       return true;
250     }
251 
252     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
253       return true;
254 
255     QualType RetTy = ArgTyA;
256     TheCall->setType(RetTy);
257     break;
258   }
259   case SPIRV::BI__builtin_spirv_faceforward: {
260     if (SemaRef.checkArgCount(TheCall, 3))
261       return true;
262 
263     // Check if first argument has floating representation
264     ExprResult A = TheCall->getArg(0);
265     QualType ArgTyA = A.get()->getType();
266     if (!ArgTyA->hasFloatingRepresentation()) {
267       SemaRef.Diag(A.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
268           << /* ordinal */ 1 << /* scalar or vector */ 5 << /* no int */ 0
269           << /* fp */ 1 << ArgTyA;
270       return true;
271     }
272 
273     if (CheckAllArgsHaveSameType(&SemaRef, TheCall))
274       return true;
275 
276     QualType RetTy = ArgTyA;
277     TheCall->setType(RetTy);
278     break;
279   }
280   case SPIRV::BI__builtin_spirv_generic_cast_to_ptr_explicit: {
281     return checkGenericCastToPtr(SemaRef, TheCall);
282   }
283   }
284   return false;
285 }
286 } // namespace clang
287