1 //===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains class to help build DXIL op functions. 10 //===----------------------------------------------------------------------===// 11 12 #include "DXILOpBuilder.h" 13 #include "DXILConstants.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/Module.h" 16 #include "llvm/Support/DXILOperationCommon.h" 17 #include "llvm/Support/ErrorHandling.h" 18 19 using namespace llvm; 20 using namespace llvm::dxil; 21 22 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 23 24 namespace { 25 26 enum OverloadKind : uint16_t { 27 VOID = 1, 28 HALF = 1 << 1, 29 FLOAT = 1 << 2, 30 DOUBLE = 1 << 3, 31 I1 = 1 << 4, 32 I8 = 1 << 5, 33 I16 = 1 << 6, 34 I32 = 1 << 7, 35 I64 = 1 << 8, 36 UserDefineType = 1 << 9, 37 ObjectType = 1 << 10, 38 }; 39 40 } // namespace 41 42 static const char *getOverloadTypeName(OverloadKind Kind) { 43 switch (Kind) { 44 case OverloadKind::HALF: 45 return "f16"; 46 case OverloadKind::FLOAT: 47 return "f32"; 48 case OverloadKind::DOUBLE: 49 return "f64"; 50 case OverloadKind::I1: 51 return "i1"; 52 case OverloadKind::I8: 53 return "i8"; 54 case OverloadKind::I16: 55 return "i16"; 56 case OverloadKind::I32: 57 return "i32"; 58 case OverloadKind::I64: 59 return "i64"; 60 case OverloadKind::VOID: 61 case OverloadKind::ObjectType: 62 case OverloadKind::UserDefineType: 63 break; 64 } 65 llvm_unreachable("invalid overload type for name"); 66 return "void"; 67 } 68 69 static OverloadKind getOverloadKind(Type *Ty) { 70 Type::TypeID T = Ty->getTypeID(); 71 switch (T) { 72 case Type::VoidTyID: 73 return OverloadKind::VOID; 74 case Type::HalfTyID: 75 return OverloadKind::HALF; 76 case Type::FloatTyID: 77 return OverloadKind::FLOAT; 78 case Type::DoubleTyID: 79 return OverloadKind::DOUBLE; 80 case Type::IntegerTyID: { 81 IntegerType *ITy = cast<IntegerType>(Ty); 82 unsigned Bits = ITy->getBitWidth(); 83 switch (Bits) { 84 case 1: 85 return OverloadKind::I1; 86 case 8: 87 return OverloadKind::I8; 88 case 16: 89 return OverloadKind::I16; 90 case 32: 91 return OverloadKind::I32; 92 case 64: 93 return OverloadKind::I64; 94 default: 95 llvm_unreachable("invalid overload type"); 96 return OverloadKind::VOID; 97 } 98 } 99 case Type::PointerTyID: 100 return OverloadKind::UserDefineType; 101 case Type::StructTyID: 102 return OverloadKind::ObjectType; 103 default: 104 llvm_unreachable("invalid overload type"); 105 return OverloadKind::VOID; 106 } 107 } 108 109 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 110 if (Kind < OverloadKind::UserDefineType) { 111 return getOverloadTypeName(Kind); 112 } else if (Kind == OverloadKind::UserDefineType) { 113 StructType *ST = cast<StructType>(Ty); 114 return ST->getStructName().str(); 115 } else if (Kind == OverloadKind::ObjectType) { 116 StructType *ST = cast<StructType>(Ty); 117 return ST->getStructName().str(); 118 } else { 119 std::string Str; 120 raw_string_ostream OS(Str); 121 Ty->print(OS); 122 return OS.str(); 123 } 124 } 125 126 // Static properties. 127 struct OpCodeProperty { 128 dxil::OpCode OpCode; 129 // Offset in DXILOpCodeNameTable. 130 unsigned OpCodeNameOffset; 131 dxil::OpCodeClass OpCodeClass; 132 // Offset in DXILOpCodeClassNameTable. 133 unsigned OpCodeClassNameOffset; 134 uint16_t OverloadTys; 135 llvm::Attribute::AttrKind FuncAttr; 136 int OverloadParamIndex; // parameter index which control the overload. 137 // When < 0, should be only 1 overload type. 138 unsigned NumOfParameters; // Number of parameters include return value. 139 unsigned ParameterTableOffset; // Offset in ParameterTable. 140 }; 141 142 // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and 143 // getOpCodeParameterKind which generated by tableGen. 144 #define DXIL_OP_OPERATION_TABLE 145 #include "DXILOperation.inc" 146 #undef DXIL_OP_OPERATION_TABLE 147 148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 149 const OpCodeProperty &Prop) { 150 if (Kind == OverloadKind::VOID) { 151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 152 } 153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 154 getTypeName(Kind, Ty)) 155 .str(); 156 } 157 158 static std::string constructOverloadTypeName(OverloadKind Kind, 159 StringRef TypeName) { 160 if (Kind == OverloadKind::VOID) 161 return TypeName.str(); 162 163 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); 164 return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); 165 } 166 167 static StructType *getOrCreateStructType(StringRef Name, 168 ArrayRef<Type *> EltTys, 169 LLVMContext &Ctx) { 170 StructType *ST = StructType::getTypeByName(Ctx, Name); 171 if (ST) 172 return ST; 173 174 return StructType::create(Ctx, EltTys, Name); 175 } 176 177 static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { 178 OverloadKind Kind = getOverloadKind(OverloadTy); 179 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); 180 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, 181 Type::getInt32Ty(Ctx)}; 182 return getOrCreateStructType(TypeName, FieldTypes, Ctx); 183 } 184 185 static StructType *getHandleType(LLVMContext &Ctx) { 186 return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx); 187 } 188 189 static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { 190 auto &Ctx = OverloadTy->getContext(); 191 switch (Kind) { 192 case ParameterKind::VOID: 193 return Type::getVoidTy(Ctx); 194 case ParameterKind::HALF: 195 return Type::getHalfTy(Ctx); 196 case ParameterKind::FLOAT: 197 return Type::getFloatTy(Ctx); 198 case ParameterKind::DOUBLE: 199 return Type::getDoubleTy(Ctx); 200 case ParameterKind::I1: 201 return Type::getInt1Ty(Ctx); 202 case ParameterKind::I8: 203 return Type::getInt8Ty(Ctx); 204 case ParameterKind::I16: 205 return Type::getInt16Ty(Ctx); 206 case ParameterKind::I32: 207 return Type::getInt32Ty(Ctx); 208 case ParameterKind::I64: 209 return Type::getInt64Ty(Ctx); 210 case ParameterKind::OVERLOAD: 211 return OverloadTy; 212 case ParameterKind::RESOURCE_RET: 213 return getResRetType(OverloadTy, Ctx); 214 case ParameterKind::DXIL_HANDLE: 215 return getHandleType(Ctx); 216 default: 217 break; 218 } 219 llvm_unreachable("Invalid parameter kind"); 220 return nullptr; 221 } 222 223 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, 224 Type *OverloadTy) { 225 SmallVector<Type *> ArgTys; 226 227 auto ParamKinds = getOpCodeParameterKind(*Prop); 228 229 for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { 230 ParameterKind Kind = ParamKinds[I]; 231 ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); 232 } 233 return FunctionType::get( 234 ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); 235 } 236 237 static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp, 238 Type *OverloadTy, Module &M) { 239 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 240 241 OverloadKind Kind = getOverloadKind(OverloadTy); 242 // FIXME: find the issue and report error in clang instead of check it in 243 // backend. 244 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 245 llvm_unreachable("invalid overload"); 246 } 247 248 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 249 // Dependent on name to dedup. 250 if (auto *Fn = M.getFunction(FnName)) 251 return FunctionCallee(Fn); 252 253 FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy); 254 return M.getOrInsertFunction(FnName, DXILOpFT); 255 } 256 257 namespace llvm { 258 namespace dxil { 259 260 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, 261 llvm::iterator_range<Use *> Args) { 262 auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M); 263 SmallVector<Value *> FullArgs; 264 FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); 265 FullArgs.append(Args.begin(), Args.end()); 266 return B.CreateCall(Fn, FullArgs); 267 } 268 269 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT, 270 bool NoOpCodeParam) { 271 272 const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 273 if (Prop->OverloadParamIndex < 0) { 274 auto &Ctx = FT->getContext(); 275 // When only has 1 overload type, just return it. 276 switch (Prop->OverloadTys) { 277 case OverloadKind::VOID: 278 return Type::getVoidTy(Ctx); 279 case OverloadKind::HALF: 280 return Type::getHalfTy(Ctx); 281 case OverloadKind::FLOAT: 282 return Type::getFloatTy(Ctx); 283 case OverloadKind::DOUBLE: 284 return Type::getDoubleTy(Ctx); 285 case OverloadKind::I1: 286 return Type::getInt1Ty(Ctx); 287 case OverloadKind::I8: 288 return Type::getInt8Ty(Ctx); 289 case OverloadKind::I16: 290 return Type::getInt16Ty(Ctx); 291 case OverloadKind::I32: 292 return Type::getInt32Ty(Ctx); 293 case OverloadKind::I64: 294 return Type::getInt64Ty(Ctx); 295 default: 296 llvm_unreachable("invalid overload type"); 297 return nullptr; 298 } 299 } 300 301 // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). 302 Type *OverloadType = FT->getReturnType(); 303 if (Prop->OverloadParamIndex != 0) { 304 // Skip Return Type and Type for DXIL opcode. 305 const unsigned SkipedParam = NoOpCodeParam ? 2 : 1; 306 OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam); 307 } 308 309 auto ParamKinds = getOpCodeParameterKind(*Prop); 310 auto Kind = ParamKinds[Prop->OverloadParamIndex]; 311 // For ResRet and CBufferRet, OverloadTy is in field of StructType. 312 if (Kind == ParameterKind::CBUFFER_RET || 313 Kind == ParameterKind::RESOURCE_RET) { 314 auto *ST = cast<StructType>(OverloadType); 315 OverloadType = ST->getElementType(0); 316 } 317 return OverloadType; 318 } 319 320 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { 321 return ::getOpCodeName(DXILOp); 322 } 323 } // namespace dxil 324 } // namespace llvm 325