xref: /freebsd/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpBuilder.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
//===- DXILOpBuilder.cpp - Helper class for building DXILOp functions -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains a helper class for building DXIL op functions.
10 //===----------------------------------------------------------------------===//
11 
12 #include "DXILOpBuilder.h"
13 #include "DXILConstants.h"
14 #include "llvm/IR/IRBuilder.h"
15 #include "llvm/IR/Module.h"
16 #include "llvm/Support/DXILABI.h"
17 #include "llvm/Support/ErrorHandling.h"
18 
19 using namespace llvm;
20 using namespace llvm::dxil;
21 
22 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23 
namespace {

// Single-bit flags naming a DXIL overload kind. An operation's OverloadTys
// field is tested against these with a bitwise AND, so each value must
// occupy its own bit.
enum OverloadKind : uint16_t {
  VOID = 0x0001,
  HALF = 0x0002,
  FLOAT = 0x0004,
  DOUBLE = 0x0008,
  I1 = 0x0010,
  I8 = 0x0020,
  I16 = 0x0040,
  I32 = 0x0080,
  I64 = 0x0100,
  UserDefineType = 0x0200,
  ObjectType = 0x0400,
};

} // namespace
41 
getOverloadTypeName(OverloadKind Kind)42 static const char *getOverloadTypeName(OverloadKind Kind) {
43   switch (Kind) {
44   case OverloadKind::HALF:
45     return "f16";
46   case OverloadKind::FLOAT:
47     return "f32";
48   case OverloadKind::DOUBLE:
49     return "f64";
50   case OverloadKind::I1:
51     return "i1";
52   case OverloadKind::I8:
53     return "i8";
54   case OverloadKind::I16:
55     return "i16";
56   case OverloadKind::I32:
57     return "i32";
58   case OverloadKind::I64:
59     return "i64";
60   case OverloadKind::VOID:
61   case OverloadKind::ObjectType:
62   case OverloadKind::UserDefineType:
63     break;
64   }
65   llvm_unreachable("invalid overload type for name");
66   return "void";
67 }
68 
getOverloadKind(Type * Ty)69 static OverloadKind getOverloadKind(Type *Ty) {
70   Type::TypeID T = Ty->getTypeID();
71   switch (T) {
72   case Type::VoidTyID:
73     return OverloadKind::VOID;
74   case Type::HalfTyID:
75     return OverloadKind::HALF;
76   case Type::FloatTyID:
77     return OverloadKind::FLOAT;
78   case Type::DoubleTyID:
79     return OverloadKind::DOUBLE;
80   case Type::IntegerTyID: {
81     IntegerType *ITy = cast<IntegerType>(Ty);
82     unsigned Bits = ITy->getBitWidth();
83     switch (Bits) {
84     case 1:
85       return OverloadKind::I1;
86     case 8:
87       return OverloadKind::I8;
88     case 16:
89       return OverloadKind::I16;
90     case 32:
91       return OverloadKind::I32;
92     case 64:
93       return OverloadKind::I64;
94     default:
95       llvm_unreachable("invalid overload type");
96       return OverloadKind::VOID;
97     }
98   }
99   case Type::PointerTyID:
100     return OverloadKind::UserDefineType;
101   case Type::StructTyID:
102     return OverloadKind::ObjectType;
103   default:
104     llvm_unreachable("invalid overload type");
105     return OverloadKind::VOID;
106   }
107 }
108 
getTypeName(OverloadKind Kind,Type * Ty)109 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
110   if (Kind < OverloadKind::UserDefineType) {
111     return getOverloadTypeName(Kind);
112   } else if (Kind == OverloadKind::UserDefineType) {
113     StructType *ST = cast<StructType>(Ty);
114     return ST->getStructName().str();
115   } else if (Kind == OverloadKind::ObjectType) {
116     StructType *ST = cast<StructType>(Ty);
117     return ST->getStructName().str();
118   } else {
119     std::string Str;
120     raw_string_ostream OS(Str);
121     Ty->print(OS);
122     return OS.str();
123   }
124 }
125 
126 // Static properties.
127 struct OpCodeProperty {
128   dxil::OpCode OpCode;
129   // Offset in DXILOpCodeNameTable.
130   unsigned OpCodeNameOffset;
131   dxil::OpCodeClass OpCodeClass;
132   // Offset in DXILOpCodeClassNameTable.
133   unsigned OpCodeClassNameOffset;
134   uint16_t OverloadTys;
135   llvm::Attribute::AttrKind FuncAttr;
136   int OverloadParamIndex;        // parameter index which control the overload.
137                                  // When < 0, should be only 1 overload type.
138   unsigned NumOfParameters;      // Number of parameters include return value.
139   unsigned ParameterTableOffset; // Offset in ParameterTable.
140 };
141 
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
144 #define DXIL_OP_OPERATION_TABLE
145 #include "DXILOperation.inc"
146 #undef DXIL_OP_OPERATION_TABLE
147 
constructOverloadName(OverloadKind Kind,Type * Ty,const OpCodeProperty & Prop)148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
149                                          const OpCodeProperty &Prop) {
150   if (Kind == OverloadKind::VOID) {
151     return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
152   }
153   return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
154           getTypeName(Kind, Ty))
155       .str();
156 }
157 
constructOverloadTypeName(OverloadKind Kind,StringRef TypeName)158 static std::string constructOverloadTypeName(OverloadKind Kind,
159                                              StringRef TypeName) {
160   if (Kind == OverloadKind::VOID)
161     return TypeName.str();
162 
163   assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
164   return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
165 }
166 
getOrCreateStructType(StringRef Name,ArrayRef<Type * > EltTys,LLVMContext & Ctx)167 static StructType *getOrCreateStructType(StringRef Name,
168                                          ArrayRef<Type *> EltTys,
169                                          LLVMContext &Ctx) {
170   StructType *ST = StructType::getTypeByName(Ctx, Name);
171   if (ST)
172     return ST;
173 
174   return StructType::create(Ctx, EltTys, Name);
175 }
176 
getResRetType(Type * OverloadTy,LLVMContext & Ctx)177 static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {
178   OverloadKind Kind = getOverloadKind(OverloadTy);
179   std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
180   Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,
181                          Type::getInt32Ty(Ctx)};
182   return getOrCreateStructType(TypeName, FieldTypes, Ctx);
183 }
184 
getHandleType(LLVMContext & Ctx)185 static StructType *getHandleType(LLVMContext &Ctx) {
186   return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),
187                                Ctx);
188 }
189 
getTypeFromParameterKind(ParameterKind Kind,Type * OverloadTy)190 static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {
191   auto &Ctx = OverloadTy->getContext();
192   switch (Kind) {
193   case ParameterKind::Void:
194     return Type::getVoidTy(Ctx);
195   case ParameterKind::Half:
196     return Type::getHalfTy(Ctx);
197   case ParameterKind::Float:
198     return Type::getFloatTy(Ctx);
199   case ParameterKind::Double:
200     return Type::getDoubleTy(Ctx);
201   case ParameterKind::I1:
202     return Type::getInt1Ty(Ctx);
203   case ParameterKind::I8:
204     return Type::getInt8Ty(Ctx);
205   case ParameterKind::I16:
206     return Type::getInt16Ty(Ctx);
207   case ParameterKind::I32:
208     return Type::getInt32Ty(Ctx);
209   case ParameterKind::I64:
210     return Type::getInt64Ty(Ctx);
211   case ParameterKind::Overload:
212     return OverloadTy;
213   case ParameterKind::ResourceRet:
214     return getResRetType(OverloadTy, Ctx);
215   case ParameterKind::DXILHandle:
216     return getHandleType(Ctx);
217   default:
218     break;
219   }
220   llvm_unreachable("Invalid parameter kind");
221   return nullptr;
222 }
223 
224 /// Construct DXIL function type. This is the type of a function with
225 /// the following prototype
226 ///     OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
227 /// <param-types> are constructed from types in Prop.
228 /// \param Prop  Structure containing DXIL Operation properties based on
229 ///               its specification in DXIL.td.
230 /// \param OverloadTy Return type to be used to construct DXIL function type.
getDXILOpFunctionType(const OpCodeProperty * Prop,Type * ReturnTy,Type * OverloadTy)231 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,
232                                            Type *ReturnTy, Type *OverloadTy) {
233   SmallVector<Type *> ArgTys;
234 
235   auto ParamKinds = getOpCodeParameterKind(*Prop);
236 
237   // Add ReturnTy as return type of the function
238   ArgTys.emplace_back(ReturnTy);
239 
240   // Add DXIL Opcode value type viz., Int32 as first argument
241   ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));
242 
243   // Add DXIL Operation parameter types as specified in DXIL properties
244   for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {
245     ParameterKind Kind = ParamKinds[I];
246     ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));
247   }
248   return FunctionType::get(
249       ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);
250 }
251 
252 namespace llvm {
253 namespace dxil {
254 
createDXILOpCall(dxil::OpCode OpCode,Type * ReturnTy,Type * OverloadTy,SmallVector<Value * > Args)255 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,
256                                           Type *OverloadTy,
257                                           SmallVector<Value *> Args) {
258   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
259 
260   OverloadKind Kind = getOverloadKind(OverloadTy);
261   if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {
262     report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);
263   }
264 
265   std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
266   FunctionCallee DXILFn;
267   // Get the function with name DXILFnName, if one exists
268   if (auto *Func = M.getFunction(DXILFnName)) {
269     DXILFn = FunctionCallee(Func);
270   } else {
271     // Construct and add a function with name DXILFnName
272     FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);
273     DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
274   }
275 
276   return B.CreateCall(DXILFn, Args);
277 }
278 
getOverloadTy(dxil::OpCode OpCode,FunctionType * FT)279 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {
280 
281   const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
282   // If DXIL Op has no overload parameter, just return the
283   // precise return type specified.
284   if (Prop->OverloadParamIndex < 0) {
285     auto &Ctx = FT->getContext();
286     switch (Prop->OverloadTys) {
287     case OverloadKind::VOID:
288       return Type::getVoidTy(Ctx);
289     case OverloadKind::HALF:
290       return Type::getHalfTy(Ctx);
291     case OverloadKind::FLOAT:
292       return Type::getFloatTy(Ctx);
293     case OverloadKind::DOUBLE:
294       return Type::getDoubleTy(Ctx);
295     case OverloadKind::I1:
296       return Type::getInt1Ty(Ctx);
297     case OverloadKind::I8:
298       return Type::getInt8Ty(Ctx);
299     case OverloadKind::I16:
300       return Type::getInt16Ty(Ctx);
301     case OverloadKind::I32:
302       return Type::getInt32Ty(Ctx);
303     case OverloadKind::I64:
304       return Type::getInt64Ty(Ctx);
305     default:
306       llvm_unreachable("invalid overload type");
307       return nullptr;
308     }
309   }
310 
311   // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().
312   Type *OverloadType = FT->getReturnType();
313   if (Prop->OverloadParamIndex != 0) {
314     // Skip Return Type.
315     OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);
316   }
317 
318   auto ParamKinds = getOpCodeParameterKind(*Prop);
319   auto Kind = ParamKinds[Prop->OverloadParamIndex];
320   // For ResRet and CBufferRet, OverloadTy is in field of StructType.
321   if (Kind == ParameterKind::CBufferRet ||
322       Kind == ParameterKind::ResourceRet) {
323     auto *ST = cast<StructType>(OverloadType);
324     OverloadType = ST->getElementType(0);
325   }
326   return OverloadType;
327 }
328 
getOpCodeName(dxil::OpCode DXILOp)329 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
330   return ::getOpCodeName(DXILOp);
331 }
332 } // namespace dxil
333 } // namespace llvm
334