//===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// /// /// \file This file contains class to help build DXIL op functions. //===----------------------------------------------------------------------===// #include "DXILOpBuilder.h" #include "DXILConstants.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Module.h" #include "llvm/Support/DXILOperationCommon.h" #include "llvm/Support/ErrorHandling.h" using namespace llvm; using namespace llvm::DXIL; constexpr StringLiteral DXILOpNamePrefix = "dx.op."; namespace { enum OverloadKind : uint16_t { VOID = 1, HALF = 1 << 1, FLOAT = 1 << 2, DOUBLE = 1 << 3, I1 = 1 << 4, I8 = 1 << 5, I16 = 1 << 6, I32 = 1 << 7, I64 = 1 << 8, UserDefineType = 1 << 9, ObjectType = 1 << 10, }; } // namespace static const char *getOverloadTypeName(OverloadKind Kind) { switch (Kind) { case OverloadKind::HALF: return "f16"; case OverloadKind::FLOAT: return "f32"; case OverloadKind::DOUBLE: return "f64"; case OverloadKind::I1: return "i1"; case OverloadKind::I8: return "i8"; case OverloadKind::I16: return "i16"; case OverloadKind::I32: return "i32"; case OverloadKind::I64: return "i64"; case OverloadKind::VOID: case OverloadKind::ObjectType: case OverloadKind::UserDefineType: break; } llvm_unreachable("invalid overload type for name"); return "void"; } static OverloadKind getOverloadKind(Type *Ty) { Type::TypeID T = Ty->getTypeID(); switch (T) { case Type::VoidTyID: return OverloadKind::VOID; case Type::HalfTyID: return OverloadKind::HALF; case Type::FloatTyID: return OverloadKind::FLOAT; case Type::DoubleTyID: return OverloadKind::DOUBLE; case Type::IntegerTyID: { IntegerType *ITy = cast(Ty); unsigned Bits = ITy->getBitWidth(); switch (Bits) { case 1: return OverloadKind::I1; case 8: return OverloadKind::I8; case 16: return OverloadKind::I16; case 32: return OverloadKind::I32; case 64: return OverloadKind::I64; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } case Type::PointerTyID: return OverloadKind::UserDefineType; case Type::StructTyID: return OverloadKind::ObjectType; default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; } } static std::string getTypeName(OverloadKind Kind, Type *Ty) { if (Kind < OverloadKind::UserDefineType) { return getOverloadTypeName(Kind); } else if (Kind == OverloadKind::UserDefineType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else if (Kind == OverloadKind::ObjectType) { StructType *ST = cast(Ty); return ST->getStructName().str(); } else { std::string Str; raw_string_ostream OS(Str); Ty->print(OS); return OS.str(); } } // Static properties. struct OpCodeProperty { DXIL::OpCode OpCode; // Offset in DXILOpCodeNameTable. unsigned OpCodeNameOffset; DXIL::OpCodeClass OpCodeClass; // Offset in DXILOpCodeClassNameTable. unsigned OpCodeClassNameOffset; uint16_t OverloadTys; llvm::Attribute::AttrKind FuncAttr; int OverloadParamIndex; // parameter index which control the overload. // When < 0, should be only 1 overload type. unsigned NumOfParameters; // Number of parameters include return value. unsigned ParameterTableOffset; // Offset in ParameterTable. }; // Include getOpCodeClassName getOpCodeProperty, getOpCodeName and // getOpCodeParameterKind which generated by tableGen. #define DXIL_OP_OPERATION_TABLE #include "DXILOperation.inc" #undef DXIL_OP_OPERATION_TABLE static std::string constructOverloadName(OverloadKind Kind, Type *Ty, const OpCodeProperty &Prop) { if (Kind == OverloadKind::VOID) { return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str(); } return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." + getTypeName(Kind, Ty)) .str(); } static std::string constructOverloadTypeName(OverloadKind Kind, StringRef TypeName) { if (Kind == OverloadKind::VOID) return TypeName.str(); assert(Kind < OverloadKind::UserDefineType && "invalid overload kind"); return (Twine(TypeName) + getOverloadTypeName(Kind)).str(); } static StructType *getOrCreateStructType(StringRef Name, ArrayRef EltTys, LLVMContext &Ctx) { StructType *ST = StructType::getTypeByName(Ctx, Name); if (ST) return ST; return StructType::create(Ctx, EltTys, Name); } static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { OverloadKind Kind = getOverloadKind(OverloadTy); std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, Type::getInt32Ty(Ctx)}; return getOrCreateStructType(TypeName, FieldTypes, Ctx); } static StructType *getHandleType(LLVMContext &Ctx) { return getOrCreateStructType("dx.types.Handle", Type::getInt8PtrTy(Ctx), Ctx); } static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) { auto &Ctx = OverloadTy->getContext(); switch (Kind) { case ParameterKind::VOID: return Type::getVoidTy(Ctx); case ParameterKind::HALF: return Type::getHalfTy(Ctx); case ParameterKind::FLOAT: return Type::getFloatTy(Ctx); case ParameterKind::DOUBLE: return Type::getDoubleTy(Ctx); case ParameterKind::I1: return Type::getInt1Ty(Ctx); case ParameterKind::I8: return Type::getInt8Ty(Ctx); case ParameterKind::I16: return Type::getInt16Ty(Ctx); case ParameterKind::I32: return Type::getInt32Ty(Ctx); case ParameterKind::I64: return Type::getInt64Ty(Ctx); case ParameterKind::OVERLOAD: return OverloadTy; case ParameterKind::RESOURCE_RET: return getResRetType(OverloadTy, Ctx); case ParameterKind::DXIL_HANDLE: return getHandleType(Ctx); default: break; } llvm_unreachable("Invalid parameter kind"); return nullptr; } static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop, Type *OverloadTy) { SmallVector ArgTys; auto ParamKinds = getOpCodeParameterKind(*Prop); for (unsigned I = 0; I < Prop->NumOfParameters; ++I) { ParameterKind Kind = ParamKinds[I]; ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy)); } return FunctionType::get( ArgTys[0], ArrayRef(&ArgTys[1], ArgTys.size() - 1), false); } static FunctionCallee getOrCreateDXILOpFunction(DXIL::OpCode DXILOp, Type *OverloadTy, Module &M) { const OpCodeProperty *Prop = getOpCodeProperty(DXILOp); OverloadKind Kind = getOverloadKind(OverloadTy); // FIXME: find the issue and report error in clang instead of check it in // backend. if ((Prop->OverloadTys & (uint16_t)Kind) == 0) { llvm_unreachable("invalid overload"); } std::string FnName = constructOverloadName(Kind, OverloadTy, *Prop); // Dependent on name to dedup. if (auto *Fn = M.getFunction(FnName)) return FunctionCallee(Fn); FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, OverloadTy); return M.getOrInsertFunction(FnName, DXILOpFT); } namespace llvm { namespace DXIL { CallInst *DXILOpBuilder::createDXILOpCall(DXIL::OpCode OpCode, Type *OverloadTy, llvm::iterator_range Args) { auto Fn = getOrCreateDXILOpFunction(OpCode, OverloadTy, M); SmallVector FullArgs; FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); FullArgs.append(Args.begin(), Args.end()); return B.CreateCall(Fn, FullArgs); } Type *DXILOpBuilder::getOverloadTy(DXIL::OpCode OpCode, FunctionType *FT, bool NoOpCodeParam) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); if (Prop->OverloadParamIndex < 0) { auto &Ctx = FT->getContext(); // When only has 1 overload type, just return it. switch (Prop->OverloadTys) { case OverloadKind::VOID: return Type::getVoidTy(Ctx); case OverloadKind::HALF: return Type::getHalfTy(Ctx); case OverloadKind::FLOAT: return Type::getFloatTy(Ctx); case OverloadKind::DOUBLE: return Type::getDoubleTy(Ctx); case OverloadKind::I1: return Type::getInt1Ty(Ctx); case OverloadKind::I8: return Type::getInt8Ty(Ctx); case OverloadKind::I16: return Type::getInt16Ty(Ctx); case OverloadKind::I32: return Type::getInt32Ty(Ctx); case OverloadKind::I64: return Type::getInt64Ty(Ctx); default: llvm_unreachable("invalid overload type"); return nullptr; } } // Prop->OverloadParamIndex is 0, overload type is FT->getReturnType(). Type *OverloadType = FT->getReturnType(); if (Prop->OverloadParamIndex != 0) { // Skip Return Type and Type for DXIL opcode. const unsigned SkipedParam = NoOpCodeParam ? 2 : 1; OverloadType = FT->getParamType(Prop->OverloadParamIndex - SkipedParam); } auto ParamKinds = getOpCodeParameterKind(*Prop); auto Kind = ParamKinds[Prop->OverloadParamIndex]; // For ResRet and CBufferRet, OverloadTy is in field of StructType. if (Kind == ParameterKind::CBUFFER_RET || Kind == ParameterKind::RESOURCE_RET) { auto *ST = cast(OverloadType); OverloadType = ST->getElementType(0); } return OverloadType; } const char *DXILOpBuilder::getOpCodeName(DXIL::OpCode DXILOp) { return ::getOpCodeName(DXILOp); } } // namespace DXIL } // namespace llvm