1 //===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file This file contains class to help build DXIL op functions. 10 //===----------------------------------------------------------------------===// 11 12 #include "DXILOpBuilder.h" 13 #include "DXILConstants.h" 14 #include "llvm/IR/IRBuilder.h" 15 #include "llvm/IR/Module.h" 16 #include "llvm/Support/DXILOperationCommon.h" 17 #include "llvm/Support/ErrorHandling.h" 18 19 using namespace llvm; 20 using namespace llvm::dxil; 21 22 constexpr StringLiteral DXILOpNamePrefix = "dx.op."; 23 24 namespace { 25 26 enum OverloadKind : uint16_t { 27 VOID = 1, 28 HALF = 1 << 1, 29 FLOAT = 1 << 2, 30 DOUBLE = 1 << 3, 31 I1 = 1 << 4, 32 I8 = 1 << 5, 33 I16 = 1 << 6, 34 I32 = 1 << 7, 35 I64 = 1 << 8, 36 UserDefineType = 1 << 9, 37 ObjectType = 1 << 10, 38 }; 39 40 } // namespace 41 42 static const char *getOverloadTypeName(OverloadKind Kind) { 43 switch (Kind) { 44 case OverloadKind::HALF: 45 return "f16"; 46 case OverloadKind::FLOAT: 47 return "f32"; 48 case OverloadKind::DOUBLE: 49 return "f64"; 50 case OverloadKind::I1: 51 return "i1"; 52 case OverloadKind::I8: 53 return "i8"; 54 case OverloadKind::I16: 55 return "i16"; 56 case OverloadKind::I32: 57 return "i32"; 58 case OverloadKind::I64: 59 return "i64"; 60 case OverloadKind::VOID: 61 case OverloadKind::ObjectType: 62 case OverloadKind::UserDefineType: 63 break; 64 } 65 llvm_unreachable("invalid overload type for name"); 66 return "void"; 67 } 68 69 static OverloadKind getOverloadKind(Type *Ty) { 70 Type::TypeID T = Ty->getTypeID(); 71 switch (T) { 72 case Type::VoidTyID: 73 return OverloadKind::VOID; 74 case Type::HalfTyID: 75 return OverloadKind::HALF; 76 case Type::FloatTyID: 77 return OverloadKind::FLOAT; 78 case Type::DoubleTyID: 79 return OverloadKind::DOUBLE; 80 case Type::IntegerTyID: { 81 IntegerType *ITy = cast<IntegerType>(Ty); 82 unsigned Bits = ITy->getBitWidth(); 83 switch (Bits) { 84 case 1: 85 return OverloadKind::I1; 86 case 8: 87 return OverloadKind::I8; 88 case 16: 89 return OverloadKind::I16; 90 case 32: 91 return OverloadKind::I32; 92 case 64: 93 return OverloadKind::I64; 94 default: 95 llvm_unreachable("invalid overload type"); 96 return OverloadKind::VOID; 97 } 98 } 99 case Type::PointerTyID: 100 return OverloadKind::UserDefineType; 101 case Type::StructTyID: 102 return OverloadKind::ObjectType; 103 default: 104 llvm_unreachable("invalid overload type"); 105 return OverloadKind::VOID; 106 } 107 } 108 109 static std::string getTypeName(OverloadKind Kind, Type *Ty) { 110 if (Kind < OverloadKind::UserDefineType) { 111 return getOverloadTypeName(Kind); 112 } else if (Kind == OverloadKind::UserDefineType) { 113 StructType *ST = cast<StructType>(Ty); 114 return ST->getStructName().str(); 115 } else if (Kind == OverloadKind::ObjectType) { 116 StructType *ST = cast<StructType>(Ty); 117 return ST->getStructName().str(); 118 } else { 119 std::string Str; 120 raw_string_ostream OS(Str); 121 Ty->print(OS); 122 return OS.str(); 123 } 124 } 125 126 // Static properties. 127 struct OpCodeProperty { 128 dxil::OpCode OpCode; 129 // Offset in DXILOpCodeNameTable. 130 unsigned OpCodeNameOffset; 131 dxil::OpCodeClass OpCodeClass; 132 // Offset in DXILOpCodeClassNameTable. 133 unsigned OpCodeClassNameOffset; 134 uint16_t OverloadTys; 135 llvm::Attribute::AttrKind FuncAttr; 136 int OverloadParamIndex; // parameter index which control the overload. 137 // When < 0, should be only 1 overload type. 138 unsigned NumOfParameters; // Number of parameters include return value. 139 unsigned ParameterTableOffset; // Offset in ParameterTable. 140 }; 141 142 // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and 143 // getOpCodeParameterKind which generated by tableGen. 144 #define DXIL_OP_OPERATION_TABLE 145 #include "DXILOperation.inc" 146 #undef DXIL_OP_OPERATION_TABLE 147 148 static std::string constructOverloadName(OverloadKind Kind, Type *Ty, 149 const OpCodeProperty &Prop) { 150 if (Kind == OverloadKind::VOID) { 151 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); 152 } 153 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + 154 getTypeName(Kind, Ty)) 155 .str(); 156 } 157 158 static std::string constructOverloadTypeName(OverloadKind Kind, 159 StringRef TypeName) { 160 if (Kind == OverloadKind::VOID) 161 return TypeName.str(); 162 163 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); 164 return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); 165 } 166 167 static StructType *getOrCreateStructType(StringRef Name, 168 ArrayRef<Type *> EltTys, 169 LLVMContext &Ctx) { 170 StructType *ST = StructType::getTypeByName(Ctx, Name); 171 if (ST) 172 return ST; 173 174 return StructType::create(Ctx, EltTys, Name); 175 } 176 177 static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { 178 OverloadKind Kind = getOverloadKind(OverloadTy); 179 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); 180 Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, 181 Type::getInt32Ty(Ctx)}; 182 return getOrCreateStructType(TypeName, FieldTypes, Ctx); 183 } 184 185 static StructType *getHandleType(LLVMContext &Ctx) { 186 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx), 187 Ctx); 188 } 189 190 static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { 191 auto &Ctx = OverloadTy->getContext(); 192 switch (Kind) { 193 case ParameterKind::VOID: 194 return Type::getVoidTy(Ctx); 195 case ParameterKind::HALF: 196 return Type::getHalfTy(Ctx); 197 case ParameterKind::FLOAT: 198 return Type::getFloatTy(Ctx); 199 case ParameterKind::DOUBLE: 200 return Type::getDoubleTy(Ctx); 201 case ParameterKind::I1: 202 return Type::getInt1Ty(Ctx); 203 case ParameterKind::I8: 204 return Type::getInt8Ty(Ctx); 205 case ParameterKind::I16: 206 return Type::getInt16Ty(Ctx); 207 case ParameterKind::I32: 208 return Type::getInt32Ty(Ctx); 209 case ParameterKind::I64: 210 return Type::getInt64Ty(Ctx); 211 case ParameterKind::OVERLOAD: 212 return OverloadTy; 213 case ParameterKind::RESOURCE_RET: 214 return getResRetType(OverloadTy, Ctx); 215 case ParameterKind::DXIL_HANDLE: 216 return getHandleType(Ctx); 217 default: 218 break; 219 } 220 llvm_unreachable("Invalid parameter kind"); 221 return nullptr; 222 } 223 224 static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, 225 Type *OverloadTy) { 226 SmallVector<Type *> ArgTys; 227 228 auto ParamKinds = getOpCodeParameterKind(*Prop); 229 230 for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { 231 ParameterKind Kind = ParamKinds[I]; 232 ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); 233 } 234 return FunctionType::get( 235 ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false); 236 } 237 238 static FunctionCallee getOrCreateDXILOpFunction(dxil::OpCode DXILOp, 239 Type *OverloadTy, Module &M) { 240 const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); 241 242 OverloadKind Kind = getOverloadKind(OverloadTy); 243 // FIXME: find the issue and report error in clang instead of check it in 244 // backend. 245 if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { 246 llvm_unreachable("invalid overload"); 247 } 248 249 std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); 250 // Dependent on name to dedup. 251 if (auto *Fn = M.getFunction(FnName)) 252 return FunctionCallee(Fn); 253 254 FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy); 255 return M.getOrInsertFunction(FnName, DXILOpFT); 256 } 257 258 namespace llvm { 259 namespace dxil { 260 261 CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *OverloadTy, 262 llvm::iterator_range<Use *> Args) { 263 auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M); 264 SmallVector<Value *> FullArgs; 265 FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); 266 FullArgs.append(Args.begin(), Args.end()); 267 return B.CreateCall(Fn, FullArgs); 268 } 269 270 Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT, 271 bool NoOpCodeParam) { 272 273 const OpCodeProperty *Prop = getOpCodeProperty(OpCode); 274 if (Prop->OverloadParamIndex < 0) { 275 auto &Ctx = FT->getContext(); 276 // When only has 1 overload type, just return it. 277 switch (Prop->OverloadTys) { 278 case OverloadKind::VOID: 279 return Type::getVoidTy(Ctx); 280 case OverloadKind::HALF: 281 return Type::getHalfTy(Ctx); 282 case OverloadKind::FLOAT: 283 return Type::getFloatTy(Ctx); 284 case OverloadKind::DOUBLE: 285 return Type::getDoubleTy(Ctx); 286 case OverloadKind::I1: 287 return Type::getInt1Ty(Ctx); 288 case OverloadKind::I8: 289 return Type::getInt8Ty(Ctx); 290 case OverloadKind::I16: 291 return Type::getInt16Ty(Ctx); 292 case OverloadKind::I32: 293 return Type::getInt32Ty(Ctx); 294 case OverloadKind::I64: 295 return Type::getInt64Ty(Ctx); 296 default: 297 llvm_unreachable("invalid overload type"); 298 return nullptr; 299 } 300 } 301 302 // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). 303 Type *OverloadType = FT->getReturnType(); 304 if (Prop->OverloadParamIndex != 0) { 305 // Skip Return Type and Type for DXIL opcode. 306 const unsigned SkipedParam = NoOpCodeParam ? 2 : 1; 307 OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam); 308 } 309 310 auto ParamKinds = getOpCodeParameterKind(*Prop); 311 auto Kind = ParamKinds[Prop->OverloadParamIndex]; 312 // For ResRet and CBufferRet, OverloadTy is in field of StructType. 313 if (Kind == ParameterKind::CBUFFER_RET || 314 Kind == ParameterKind::RESOURCE_RET) { 315 auto *ST = cast<StructType>(OverloadType); 316 OverloadType = ST->getElementType(0); 317 } 318 return OverloadType; 319 } 320 321 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) { 322 return ::getOpCodeName(DXILOp); 323 } 324 } // namespace dxil 325 } // namespace llvm 326