1 //===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains class to help build DXIL op functions. 10 //===----------------------------------------------------------------------===// 11 12 #include "DXILOpBuilder.h" 13 #include "DXILConstants.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/Module.h" 16 #include "llvm/Support/DXILABI.h" 17 #include "llvm/Support/ErrorHandling.h" 18 19 using namespace llvm; 20 using namespace llvm::dxil; 21 22 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 23 24 namespace { 25 26 enum OverloadKind : uint16_t { 27 VOID = 1, 28 HALF = 1 << 1, 29 FLOAT = 1 << 2, 30 DOUBLE = 1 << 3, 31 I1 = 1 << 4, 32 I8 = 1 << 5, 33 I16 = 1 << 6, 34 I32 = 1 << 7, 35 I64 = 1 << 8, 36 UserDefineType = 1 << 9, 37 ObjectType = 1 << 10, 38 }; 39 40 } // namespace 41 42 static const char *getOverloadTypeName(OverloadKind Kind) { 43 switch (Kind) { 44 case OverloadKind::HALF: 45 return "f16"; 46 case OverloadKind::FLOAT: 47 return "f32"; 48 case OverloadKind::DOUBLE: 49 return "f64"; 50 case OverloadKind::I1: 51 return "i1"; 52 case OverloadKind::I8: 53 return "i8"; 54 case OverloadKind::I16: 55 return "i16"; 56 case OverloadKind::I32: 57 return "i32"; 58 case OverloadKind::I64: 59 return "i64"; 60 case OverloadKind::VOID: 61 case OverloadKind::ObjectType: 62 case OverloadKind::UserDefineType: 63 break; 64 } 65 llvm_unreachable("invalid overload type for name"); 66 return "void"; 67 } 68 69 static OverloadKind getOverloadKind(Type *Ty) { 70 Type::TypeID T = Ty->getTypeID(); 71 switch (T) { 72 case Type::VoidTyID: 73 return OverloadKind::VOID; 74 case Type::HalfTyID: 75 return OverloadKind::HALF; 76 case Type::FloatTyID: 77 return OverloadKind::FLOAT; 78 case Type::DoubleTyID: 79 return OverloadKind::DOUBLE; 80 case Type::IntegerTyID: { 81 IntegerType *ITy = cast<IntegerType>(Ty); 82 unsigned Bits = ITy->getBitWidth(); 83 switch (Bits) { 84 case 1: 85 return OverloadKind::I1; 86 case 8: 87 return OverloadKind::I8; 88 case 16: 89 return OverloadKind::I16; 90 case 32: 91 return OverloadKind::I32; 92 case 64: 93 return OverloadKind::I64; 94 default: 95 llvm_unreachable("invalid overload type"); 96 return OverloadKind::VOID; 97 } 98 } 99 case Type::PointerTyID: 100 return OverloadKind::UserDefineType; 101 case Type::StructTyID: 102 return OverloadKind::ObjectType; 103 default: 104 llvm_unreachable("invalid overload type"); 105 return OverloadKind::VOID; 106 } 107 } 108 109 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 110 if (Kind < OverloadKind::UserDefineType) { 111 return getOverloadTypeName(Kind); 112 } else if (Kind == OverloadKind::UserDefineType) { 113 StructType *ST = cast<StructType>(Ty); 114 return ST->getStructName().str(); 115 } else if (Kind == OverloadKind::ObjectType) { 116 StructType *ST = cast<StructType>(Ty); 117 return ST->getStructName().str(); 118 } else { 119 std::string Str; 120 raw_string_ostream OS(Str); 121 Ty->print(OS); 122 return OS.str(); 123 } 124 } 125 126 // Static properties. 127 struct OpCodeProperty { 128 dxil::OpCode OpCode; 129 // Offset in DXILOpCodeNameTable. 130 unsigned OpCodeNameOffset; 131 dxil::OpCodeClass OpCodeClass; 132 // Offset in DXILOpCodeClassNameTable. 133 unsigned OpCodeClassNameOffset; 134 uint16_t OverloadTys; 135 llvm::Attribute::AttrKind FuncAttr; 136 int OverloadParamIndex; // parameter index which control the overload. 137 // When < 0, should be only 1 overload type. 138 unsigned NumOfParameters; // Number of parameters include return value. 139 unsigned ParameterTableOffset; // Offset in ParameterTable. 140 }; 141 142 // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and 143 // getOpCodeParameterKind which generated by tableGen. 144 #define DXIL_OP_OPERATION_TABLE 145 #include "DXILOperation.inc" 146 #undef DXIL_OP_OPERATION_TABLE 147 148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 149 const OpCodeProperty &Prop) { 150 if (Kind == OverloadKind::VOID) { 151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 152 } 153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 154 getTypeName(Kind, Ty)) 155 .str(); 156 } 157 158 static std::string constructOverloadTypeName(OverloadKind Kind, 159 StringRef TypeName) { 160 if (Kind == OverloadKind::VOID) 161 return TypeName.str(); 162 163 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); 164 return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); 165 } 166 167 static StructType *getOrCreateStructType(StringRef Name, 168 ArrayRef<Type *> EltTys, 169 LLVMContext &Ctx) { 170 StructType *ST = StructType::getTypeByName(Ctx, Name); 171 if (ST) 172 return ST; 173 174 return StructType::create(Ctx, EltTys, Name); 175 } 176 177 static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { 178 OverloadKind Kind = getOverloadKind(OverloadTy); 179 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); 180 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, 181 Type::getInt32Ty(Ctx)}; 182 return getOrCreateStructType(TypeName, FieldTypes, Ctx); 183 } 184 185 static StructType *getHandleType(LLVMContext &Ctx) { 186 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx), 187 Ctx); 188 } 189 190 static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { 191 auto &Ctx = OverloadTy->getContext(); 192 switch (Kind) { 193 case ParameterKind::Void: 194 return Type::getVoidTy(Ctx); 195 case ParameterKind::Half: 196 return Type::getHalfTy(Ctx); 197 case ParameterKind::Float: 198 return Type::getFloatTy(Ctx); 199 case ParameterKind::Double: 200 return Type::getDoubleTy(Ctx); 201 case ParameterKind::I1: 202 return Type::getInt1Ty(Ctx); 203 case ParameterKind::I8: 204 return Type::getInt8Ty(Ctx); 205 case ParameterKind::I16: 206 return Type::getInt16Ty(Ctx); 207 case ParameterKind::I32: 208 return Type::getInt32Ty(Ctx); 209 case ParameterKind::I64: 210 return Type::getInt64Ty(Ctx); 211 case ParameterKind::Overload: 212 return OverloadTy; 213 case ParameterKind::ResourceRet: 214 return getResRetType(OverloadTy, Ctx); 215 case ParameterKind::DXILHandle: 216 return getHandleType(Ctx); 217 default: 218 break; 219 } 220 llvm_unreachable("Invalid parameter kind"); 221 return nullptr; 222 } 223 224 /// Construct DXIL function type. This is the type of a function with 225 /// the following prototype 226 /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>) 227 /// <param-types> are constructed from types in Prop. 228 /// \param Prop Structure containing DXIL Operation properties based on 229 /// its specification in DXIL.td. 230 /// \param OverloadTy Return type to be used to construct DXIL function type. 231 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, 232 Type *ReturnTy, Type *OverloadTy) { 233 SmallVector<Type *> ArgTys; 234 235 auto ParamKinds = getOpCodeParameterKind(*Prop); 236 237 // Add ReturnTy as return type of the function 238 ArgTys.emplace_back(ReturnTy); 239 240 // Add DXIL Opcode value type viz., Int32 as first argument 241 ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext())); 242 243 // Add DXIL Operation parameter types as specified in DXIL properties 244 for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { 245 ParameterKind Kind = ParamKinds[I]; 246 ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); 247 } 248 return FunctionType::get( 249 ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); 250 } 251 252 namespace llvm { 253 namespace dxil { 254 255 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, 256 Type *OverloadTy, 257 SmallVector<Value *> Args) { 258 const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 259 260 OverloadKind Kind = getOverloadKind(OverloadTy); 261 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 262 report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false); 263 } 264 265 std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop); 266 FunctionCallee DXILFn; 267 // Get the function with name DXILFnName, if one exists 268 if (auto *Func = M.getFunction(DXILFnName)) { 269 DXILFn = FunctionCallee(Func); 270 } else { 271 // Construct and add a function with name DXILFnName 272 FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); 273 DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); 274 } 275 276 return B.CreateCall(DXILFn, Args); 277 } 278 279 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { 280 281 const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 282 // If DXIL Op has no overload parameter, just return the 283 // precise return type specified. 284 if (Prop->OverloadParamIndex < 0) { 285 auto &Ctx = FT->getContext(); 286 switch (Prop->OverloadTys) { 287 case OverloadKind::VOID: 288 return Type::getVoidTy(Ctx); 289 case OverloadKind::HALF: 290 return Type::getHalfTy(Ctx); 291 case OverloadKind::FLOAT: 292 return Type::getFloatTy(Ctx); 293 case OverloadKind::DOUBLE: 294 return Type::getDoubleTy(Ctx); 295 case OverloadKind::I1: 296 return Type::getInt1Ty(Ctx); 297 case OverloadKind::I8: 298 return Type::getInt8Ty(Ctx); 299 case OverloadKind::I16: 300 return Type::getInt16Ty(Ctx); 301 case OverloadKind::I32: 302 return Type::getInt32Ty(Ctx); 303 case OverloadKind::I64: 304 return Type::getInt64Ty(Ctx); 305 default: 306 llvm_unreachable("invalid overload type"); 307 return nullptr; 308 } 309 } 310 311 // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). 312 Type *OverloadType = FT->getReturnType(); 313 if (Prop->OverloadParamIndex != 0) { 314 // Skip Return Type. 315 OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1); 316 } 317 318 auto ParamKinds = getOpCodeParameterKind(*Prop); 319 auto Kind = ParamKinds[Prop->OverloadParamIndex]; 320 // For ResRet and CBufferRet, OverloadTy is in field of StructType. 321 if (Kind == ParameterKind::CBufferRet || 322 Kind == ParameterKind::ResourceRet) { 323 auto *ST = cast<StructType>(OverloadType); 324 OverloadType = ST->getElementType(0); 325 } 326 return OverloadType; 327 } 328 329 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { 330 return ::getOpCodeName(DXILOp); 331 } 332 } // namespace dxil 333 } // namespace llvm 334