//===- DXILOpBuilder.cpp - Helper class to build DXIL op functions --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
/// \file This file contains a class to help build DXIL op functions.
10 //===----------------------------------------------------------------------===//
11
12 #include "DXILOpBuilder.h"
13 #include "DXILConstants.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/Support/DXILABI.h"
16 #include "llvm/Support/ErrorHandling.h"
17 #include <optional>
18
19 using namespace llvm;
20 using namespace llvm::dxil;
21
22 constexpr StringLiteral DXILOpNamePrefix = "dx.op.";
23
namespace {
/// Categories of types a DXIL operation may be overloaded on. Apart from
/// UNDEFINED, every enumerator owns a distinct bit so that several kinds can
/// be OR-ed together into a 16-bit mask (see OpOverload::ValidTys).
enum OverloadKind : uint16_t {
  UNDEFINED = 0,
  VOID = 1 << 0,
  HALF = 1 << 1,
  FLOAT = 1 << 2,
  DOUBLE = 1 << 3,
  I1 = 1 << 4,
  I8 = 1 << 5,
  I16 = 1 << 6,
  I32 = 1 << 7,
  I64 = 1 << 8,
  UserDefineType = 1 << 9,
  ObjectType = 1 << 10,
};

/// A major/minor version pair. Both components default to zero.
struct Version {
  unsigned Major = 0;
  unsigned Minor = 0;
};

/// Associates a DXIL version with a mask of overload kinds.
struct OpOverload {
  /// DXIL version this entry is recorded for.
  Version DXILVersion;
  /// Bitwise OR of OverloadKind values.
  uint16_t ValidTys;
};
} // namespace
49
50 struct OpStage {
51 Version DXILVersion;
52 uint32_t ValidStages;
53 };
54
getOverloadTypeName(OverloadKind Kind)55 static const char *getOverloadTypeName(OverloadKind Kind) {
56 switch (Kind) {
57 case OverloadKind::HALF:
58 return "f16";
59 case OverloadKind::FLOAT:
60 return "f32";
61 case OverloadKind::DOUBLE:
62 return "f64";
63 case OverloadKind::I1:
64 return "i1";
65 case OverloadKind::I8:
66 return "i8";
67 case OverloadKind::I16:
68 return "i16";
69 case OverloadKind::I32:
70 return "i32";
71 case OverloadKind::I64:
72 return "i64";
73 case OverloadKind::VOID:
74 case OverloadKind::UNDEFINED:
75 return "void";
76 case OverloadKind::ObjectType:
77 case OverloadKind::UserDefineType:
78 break;
79 }
80 llvm_unreachable("invalid overload type for name");
81 }
82
getOverloadKind(Type * Ty)83 static OverloadKind getOverloadKind(Type *Ty) {
84 if (!Ty)
85 return OverloadKind::VOID;
86
87 Type::TypeID T = Ty->getTypeID();
88 switch (T) {
89 case Type::VoidTyID:
90 return OverloadKind::VOID;
91 case Type::HalfTyID:
92 return OverloadKind::HALF;
93 case Type::FloatTyID:
94 return OverloadKind::FLOAT;
95 case Type::DoubleTyID:
96 return OverloadKind::DOUBLE;
97 case Type::IntegerTyID: {
98 IntegerType *ITy = cast<IntegerType>(Ty);
99 unsigned Bits = ITy->getBitWidth();
100 switch (Bits) {
101 case 1:
102 return OverloadKind::I1;
103 case 8:
104 return OverloadKind::I8;
105 case 16:
106 return OverloadKind::I16;
107 case 32:
108 return OverloadKind::I32;
109 case 64:
110 return OverloadKind::I64;
111 default:
112 llvm_unreachable("invalid overload type");
113 return OverloadKind::VOID;
114 }
115 }
116 case Type::PointerTyID:
117 return OverloadKind::UserDefineType;
118 case Type::StructTyID: {
119 // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework
120 // how we're handling overloads and remove the `OverloadKind` proxy enum.
121 StructType *ST = cast<StructType>(Ty);
122 return getOverloadKind(ST->getElementType(0));
123 }
124 default:
125 return OverloadKind::UNDEFINED;
126 }
127 }
128
getTypeName(OverloadKind Kind,Type * Ty)129 static std::string getTypeName(OverloadKind Kind, Type *Ty) {
130 if (Kind < OverloadKind::UserDefineType) {
131 return getOverloadTypeName(Kind);
132 } else if (Kind == OverloadKind::UserDefineType) {
133 StructType *ST = cast<StructType>(Ty);
134 return ST->getStructName().str();
135 } else if (Kind == OverloadKind::ObjectType) {
136 StructType *ST = cast<StructType>(Ty);
137 return ST->getStructName().str();
138 } else {
139 std::string Str;
140 raw_string_ostream OS(Str);
141 Ty->print(OS);
142 return OS.str();
143 }
144 }
145
146 // Static properties.
147 struct OpCodeProperty {
148 dxil::OpCode OpCode;
149 // Offset in DXILOpCodeNameTable.
150 unsigned OpCodeNameOffset;
151 dxil::OpCodeClass OpCodeClass;
152 // Offset in DXILOpCodeClassNameTable.
153 unsigned OpCodeClassNameOffset;
154 llvm::SmallVector<OpOverload> Overloads;
155 llvm::SmallVector<OpStage> Stages;
156 int OverloadParamIndex; // parameter index which control the overload.
157 // When < 0, should be only 1 overload type.
158 };
159
// Include getOpCodeClassName, getOpCodeProperty, getOpCodeName and
// getOpCodeParameterKind, which are generated by TableGen.
162 #define DXIL_OP_OPERATION_TABLE
163 #include "DXILOperation.inc"
164 #undef DXIL_OP_OPERATION_TABLE
165
constructOverloadName(OverloadKind Kind,Type * Ty,const OpCodeProperty & Prop)166 static std::string constructOverloadName(OverloadKind Kind, Type *Ty,
167 const OpCodeProperty &Prop) {
168 if (Kind == OverloadKind::VOID) {
169 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();
170 }
171 return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +
172 getTypeName(Kind, Ty))
173 .str();
174 }
175
constructOverloadTypeName(OverloadKind Kind,StringRef TypeName)176 static std::string constructOverloadTypeName(OverloadKind Kind,
177 StringRef TypeName) {
178 if (Kind == OverloadKind::VOID)
179 return TypeName.str();
180
181 assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");
182 return (Twine(TypeName) + getOverloadTypeName(Kind)).str();
183 }
184
getOrCreateStructType(StringRef Name,ArrayRef<Type * > EltTys,LLVMContext & Ctx)185 static StructType *getOrCreateStructType(StringRef Name,
186 ArrayRef<Type *> EltTys,
187 LLVMContext &Ctx) {
188 StructType *ST = StructType::getTypeByName(Ctx, Name);
189 if (ST)
190 return ST;
191
192 return StructType::create(Ctx, EltTys, Name);
193 }
194
getResRetType(Type * ElementTy)195 static StructType *getResRetType(Type *ElementTy) {
196 LLVMContext &Ctx = ElementTy->getContext();
197 OverloadKind Kind = getOverloadKind(ElementTy);
198 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");
199 Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy,
200 Type::getInt32Ty(Ctx)};
201 return getOrCreateStructType(TypeName, FieldTypes, Ctx);
202 }
203
getCBufRetType(Type * ElementTy)204 static StructType *getCBufRetType(Type *ElementTy) {
205 LLVMContext &Ctx = ElementTy->getContext();
206 OverloadKind Kind = getOverloadKind(ElementTy);
207 std::string TypeName = constructOverloadTypeName(Kind, "dx.types.CBufRet.");
208
209 // 64-bit types only have two elements
210 if (ElementTy->isDoubleTy() || ElementTy->isIntegerTy(64))
211 return getOrCreateStructType(TypeName, {ElementTy, ElementTy}, Ctx);
212
213 // 16-bit types pack 8 elements and have .8 in their name to differentiate
214 // from min-precision types.
215 if (ElementTy->isHalfTy() || ElementTy->isIntegerTy(16)) {
216 TypeName += ".8";
217 return getOrCreateStructType(TypeName,
218 {ElementTy, ElementTy, ElementTy, ElementTy,
219 ElementTy, ElementTy, ElementTy, ElementTy},
220 Ctx);
221 }
222
223 return getOrCreateStructType(
224 TypeName, {ElementTy, ElementTy, ElementTy, ElementTy}, Ctx);
225 }
226
getHandleType(LLVMContext & Ctx)227 static StructType *getHandleType(LLVMContext &Ctx) {
228 return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),
229 Ctx);
230 }
231
getResBindType(LLVMContext & Context)232 static StructType *getResBindType(LLVMContext &Context) {
233 if (auto *ST = StructType::getTypeByName(Context, "dx.types.ResBind"))
234 return ST;
235 Type *Int32Ty = Type::getInt32Ty(Context);
236 Type *Int8Ty = Type::getInt8Ty(Context);
237 return StructType::create({Int32Ty, Int32Ty, Int32Ty, Int8Ty},
238 "dx.types.ResBind");
239 }
240
getResPropsType(LLVMContext & Context)241 static StructType *getResPropsType(LLVMContext &Context) {
242 if (auto *ST =
243 StructType::getTypeByName(Context, "dx.types.ResourceProperties"))
244 return ST;
245 Type *Int32Ty = Type::getInt32Ty(Context);
246 return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties");
247 }
248
getSplitDoubleType(LLVMContext & Context)249 static StructType *getSplitDoubleType(LLVMContext &Context) {
250 if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble"))
251 return ST;
252 Type *Int32Ty = Type::getInt32Ty(Context);
253 return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble");
254 }
255
getBinaryWithCarryType(LLVMContext & Context)256 static StructType *getBinaryWithCarryType(LLVMContext &Context) {
257 if (auto *ST = StructType::getTypeByName(Context, "dx.types.i32c"))
258 return ST;
259 Type *Int32Ty = Type::getInt32Ty(Context);
260 Type *Int1Ty = Type::getInt1Ty(Context);
261 return StructType::create({Int32Ty, Int1Ty}, "dx.types.i32c");
262 }
263
getTypeFromOpParamType(OpParamType Kind,LLVMContext & Ctx,Type * OverloadTy)264 static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx,
265 Type *OverloadTy) {
266 switch (Kind) {
267 case OpParamType::VoidTy:
268 return Type::getVoidTy(Ctx);
269 case OpParamType::HalfTy:
270 return Type::getHalfTy(Ctx);
271 case OpParamType::FloatTy:
272 return Type::getFloatTy(Ctx);
273 case OpParamType::DoubleTy:
274 return Type::getDoubleTy(Ctx);
275 case OpParamType::Int1Ty:
276 return Type::getInt1Ty(Ctx);
277 case OpParamType::Int8Ty:
278 return Type::getInt8Ty(Ctx);
279 case OpParamType::Int16Ty:
280 return Type::getInt16Ty(Ctx);
281 case OpParamType::Int32Ty:
282 return Type::getInt32Ty(Ctx);
283 case OpParamType::Int64Ty:
284 return Type::getInt64Ty(Ctx);
285 case OpParamType::OverloadTy:
286 return OverloadTy;
287 case OpParamType::ResRetHalfTy:
288 return getResRetType(Type::getHalfTy(Ctx));
289 case OpParamType::ResRetFloatTy:
290 return getResRetType(Type::getFloatTy(Ctx));
291 case OpParamType::ResRetDoubleTy:
292 return getResRetType(Type::getDoubleTy(Ctx));
293 case OpParamType::ResRetInt16Ty:
294 return getResRetType(Type::getInt16Ty(Ctx));
295 case OpParamType::ResRetInt32Ty:
296 return getResRetType(Type::getInt32Ty(Ctx));
297 case OpParamType::ResRetInt64Ty:
298 return getResRetType(Type::getInt64Ty(Ctx));
299 case OpParamType::CBufRetHalfTy:
300 return getCBufRetType(Type::getHalfTy(Ctx));
301 case OpParamType::CBufRetFloatTy:
302 return getCBufRetType(Type::getFloatTy(Ctx));
303 case OpParamType::CBufRetDoubleTy:
304 return getCBufRetType(Type::getDoubleTy(Ctx));
305 case OpParamType::CBufRetInt16Ty:
306 return getCBufRetType(Type::getInt16Ty(Ctx));
307 case OpParamType::CBufRetInt32Ty:
308 return getCBufRetType(Type::getInt32Ty(Ctx));
309 case OpParamType::CBufRetInt64Ty:
310 return getCBufRetType(Type::getInt64Ty(Ctx));
311 case OpParamType::HandleTy:
312 return getHandleType(Ctx);
313 case OpParamType::ResBindTy:
314 return getResBindType(Ctx);
315 case OpParamType::ResPropsTy:
316 return getResPropsType(Ctx);
317 case OpParamType::SplitDoubleTy:
318 return getSplitDoubleType(Ctx);
319 case OpParamType::BinaryWithCarryTy:
320 return getBinaryWithCarryType(Ctx);
321 }
322 llvm_unreachable("Invalid parameter kind");
323 return nullptr;
324 }
325
getShaderKindEnum(Triple::EnvironmentType EnvType)326 static ShaderKind getShaderKindEnum(Triple::EnvironmentType EnvType) {
327 switch (EnvType) {
328 case Triple::Pixel:
329 return ShaderKind::pixel;
330 case Triple::Vertex:
331 return ShaderKind::vertex;
332 case Triple::Geometry:
333 return ShaderKind::geometry;
334 case Triple::Hull:
335 return ShaderKind::hull;
336 case Triple::Domain:
337 return ShaderKind::domain;
338 case Triple::Compute:
339 return ShaderKind::compute;
340 case Triple::Library:
341 return ShaderKind::library;
342 case Triple::RayGeneration:
343 return ShaderKind::raygeneration;
344 case Triple::Intersection:
345 return ShaderKind::intersection;
346 case Triple::AnyHit:
347 return ShaderKind::anyhit;
348 case Triple::ClosestHit:
349 return ShaderKind::closesthit;
350 case Triple::Miss:
351 return ShaderKind::miss;
352 case Triple::Callable:
353 return ShaderKind::callable;
354 case Triple::Mesh:
355 return ShaderKind::mesh;
356 case Triple::Amplification:
357 return ShaderKind::amplification;
358 default:
359 break;
360 }
361 llvm_unreachable(
362 "Shader Kind Not Found - Invalid DXIL Environment Specified");
363 }
364
365 static SmallVector<Type *>
getArgTypesFromOpParamTypes(ArrayRef<dxil::OpParamType> Types,LLVMContext & Context,Type * OverloadTy)366 getArgTypesFromOpParamTypes(ArrayRef<dxil::OpParamType> Types,
367 LLVMContext &Context, Type *OverloadTy) {
368 SmallVector<Type *> ArgTys;
369 ArgTys.emplace_back(Type::getInt32Ty(Context));
370 for (dxil::OpParamType Ty : Types)
371 ArgTys.emplace_back(getTypeFromOpParamType(Ty, Context, OverloadTy));
372 return ArgTys;
373 }
374
375 /// Construct DXIL function type. This is the type of a function with
376 /// the following prototype
377 /// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)
378 /// <param-types> are constructed from types in Prop.
getDXILOpFunctionType(dxil::OpCode OpCode,LLVMContext & Context,Type * OverloadTy)379 static FunctionType *getDXILOpFunctionType(dxil::OpCode OpCode,
380 LLVMContext &Context,
381 Type *OverloadTy) {
382
383 switch (OpCode) {
384 #define DXIL_OP_FUNCTION_TYPE(OpCode, RetType, ...) \
385 case OpCode: \
386 return FunctionType::get( \
387 getTypeFromOpParamType(RetType, Context, OverloadTy), \
388 getArgTypesFromOpParamTypes({__VA_ARGS__}, Context, OverloadTy), \
389 /*isVarArg=*/false);
390 #include "DXILOperation.inc"
391 }
392 llvm_unreachable("Invalid OpCode?");
393 }
394
395 /// Get index of the property from PropList valid for the most recent
396 /// DXIL version not greater than DXILVer.
397 /// PropList is expected to be sorted in ascending order of DXIL version.
398 template <typename T>
getPropIndex(ArrayRef<T> PropList,const VersionTuple DXILVer)399 static std::optional<size_t> getPropIndex(ArrayRef<T> PropList,
400 const VersionTuple DXILVer) {
401 size_t Index = PropList.size() - 1;
402 for (auto Iter = PropList.rbegin(); Iter != PropList.rend();
403 Iter++, Index--) {
404 const T &Prop = *Iter;
405 if (VersionTuple(Prop.DXILVersion.Major, Prop.DXILVersion.Minor) <=
406 DXILVer) {
407 return Index;
408 }
409 }
410 return std::nullopt;
411 }
412
413 // Helper function to pack an OpCode and VersionTuple into a uint64_t for use
414 // in a switch statement
computeSwitchEnum(dxil::OpCode OpCode,uint16_t VersionMajor,uint16_t VersionMinor)415 constexpr static uint64_t computeSwitchEnum(dxil::OpCode OpCode,
416 uint16_t VersionMajor,
417 uint16_t VersionMinor) {
418 uint64_t OpCodePack = (uint64_t)OpCode;
419 return (OpCodePack << 32) | (VersionMajor << 16) | VersionMinor;
420 }
421
422 // Retreive all the set attributes for a DXIL OpCode given the targeted
423 // DXILVersion
getDXILAttributes(dxil::OpCode OpCode,VersionTuple DXILVersion)424 static dxil::Attributes getDXILAttributes(dxil::OpCode OpCode,
425 VersionTuple DXILVersion) {
426 // Instantiate all versions to iterate through
427 SmallVector<Version> Versions = {
428 #define DXIL_VERSION(MAJOR, MINOR) {MAJOR, MINOR},
429 #include "DXILOperation.inc"
430 };
431
432 dxil::Attributes Attributes;
433 for (auto Version : Versions) {
434 if (DXILVersion < VersionTuple(Version.Major, Version.Minor))
435 continue;
436
437 // Switch through and match an OpCode with the specific version and set the
438 // corresponding flag(s) if available
439 switch (computeSwitchEnum(OpCode, Version.Major, Version.Minor)) {
440 #define DXIL_OP_ATTRIBUTES(OpCode, VersionMajor, VersionMinor, ...) \
441 case computeSwitchEnum(OpCode, VersionMajor, VersionMinor): { \
442 auto Other = dxil::Attributes{__VA_ARGS__}; \
443 Attributes |= Other; \
444 break; \
445 };
446 #include "DXILOperation.inc"
447 }
448 }
449 return Attributes;
450 }
451
452 // Retreive the set of DXIL Attributes given the version and map them to an
453 // llvm function attribute that is set onto the instruction
setDXILAttributes(CallInst * CI,dxil::OpCode OpCode,VersionTuple DXILVersion)454 static void setDXILAttributes(CallInst *CI, dxil::OpCode OpCode,
455 VersionTuple DXILVersion) {
456 dxil::Attributes Attributes = getDXILAttributes(OpCode, DXILVersion);
457 if (Attributes.ReadNone)
458 CI->setDoesNotAccessMemory();
459 if (Attributes.ReadOnly)
460 CI->setOnlyReadsMemory();
461 if (Attributes.NoReturn)
462 CI->setDoesNotReturn();
463 if (Attributes.NoDuplicate)
464 CI->setCannotDuplicate();
465 return;
466 }
467
468 namespace llvm {
469 namespace dxil {
470
471 // No extra checks on TargetTriple need be performed to verify that the
472 // Triple is well-formed or that the target is supported since these checks
473 // would have been done at the time the module M is constructed in the earlier
474 // stages of compilation.
DXILOpBuilder(Module & M)475 DXILOpBuilder::DXILOpBuilder(Module &M) : M(M), IRB(M.getContext()) {
476 const Triple &TT = M.getTargetTriple();
477 DXILVersion = TT.getDXILVersion();
478 ShaderStage = TT.getEnvironment();
479 // Ensure Environment type is known
480 if (ShaderStage == Triple::UnknownEnvironment) {
481 reportFatalUsageError(
482 Twine(DXILVersion.getAsString()) +
483 ": Unknown Compilation Target Shader Stage specified ");
484 }
485 }
486
makeOpError(dxil::OpCode OpCode,Twine Msg)487 static Error makeOpError(dxil::OpCode OpCode, Twine Msg) {
488 return make_error<StringError>(
489 Twine("Cannot create ") + getOpCodeName(OpCode) + " operation: " + Msg,
490 inconvertibleErrorCode());
491 }
492
tryCreateOp(dxil::OpCode OpCode,ArrayRef<Value * > Args,const Twine & Name,Type * RetTy)493 Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode,
494 ArrayRef<Value *> Args,
495 const Twine &Name,
496 Type *RetTy) {
497 const OpCodeProperty *Prop = getOpCodeProperty(OpCode);
498
499 Type *OverloadTy = nullptr;
500 if (Prop->OverloadParamIndex == 0) {
501 if (!RetTy)
502 return makeOpError(OpCode, "Op overloaded on unknown return type");
503 OverloadTy = RetTy;
504 } else if (Prop->OverloadParamIndex > 0) {
505 // The index counts including the return type
506 unsigned ArgIndex = Prop->OverloadParamIndex - 1;
507 if (static_cast<unsigned>(ArgIndex) >= Args.size())
508 return makeOpError(OpCode, "Wrong number of arguments");
509 OverloadTy = Args[ArgIndex]->getType();
510 }
511
512 FunctionType *DXILOpFT =
513 getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy);
514
515 std::optional<size_t> OlIndexOrErr =
516 getPropIndex(ArrayRef(Prop->Overloads), DXILVersion);
517 if (!OlIndexOrErr.has_value())
518 return makeOpError(OpCode, Twine("No valid overloads for DXIL version ") +
519 DXILVersion.getAsString());
520
521 uint16_t ValidTyMask = Prop->Overloads[*OlIndexOrErr].ValidTys;
522
523 OverloadKind Kind = getOverloadKind(OverloadTy);
524
525 // Check if the operation supports overload types and OverloadTy is valid
526 // per the specified types for the operation
527 if ((ValidTyMask != OverloadKind::UNDEFINED) &&
528 (ValidTyMask & (uint16_t)Kind) == 0)
529 return makeOpError(OpCode, "Invalid overload type");
530
531 // Perform necessary checks to ensure Opcode is valid in the targeted shader
532 // kind
533 std::optional<size_t> StIndexOrErr =
534 getPropIndex(ArrayRef(Prop->Stages), DXILVersion);
535 if (!StIndexOrErr.has_value())
536 return makeOpError(OpCode, Twine("No valid stage for DXIL version ") +
537 DXILVersion.getAsString());
538
539 uint16_t ValidShaderKindMask = Prop->Stages[*StIndexOrErr].ValidStages;
540
541 // Ensure valid shader stage properties are specified
542 if (ValidShaderKindMask == ShaderKind::removed)
543 return makeOpError(OpCode, "Operation has been removed");
544
545 // Shader stage need not be validated since getShaderKindEnum() fails
546 // for unknown shader stage.
547
548 // Verify the target shader stage is valid for the DXIL operation
549 ShaderKind ModuleStagekind = getShaderKindEnum(ShaderStage);
550 if (!(ValidShaderKindMask & ModuleStagekind))
551 return makeOpError(OpCode, "Invalid stage");
552
553 std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);
554 FunctionCallee DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);
555
556 // We need to inject the opcode as the first argument.
557 SmallVector<Value *> OpArgs;
558 OpArgs.push_back(IRB.getInt32(llvm::to_underlying(OpCode)));
559 OpArgs.append(Args.begin(), Args.end());
560
561 // Create the function call instruction
562 CallInst *CI = IRB.CreateCall(DXILFn, OpArgs, Name);
563
564 // We then need to attach available function attributes
565 setDXILAttributes(CI, OpCode, DXILVersion);
566
567 return CI;
568 }
569
createOp(dxil::OpCode OpCode,ArrayRef<Value * > Args,const Twine & Name,Type * RetTy)570 CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> Args,
571 const Twine &Name, Type *RetTy) {
572 Expected<CallInst *> Result = tryCreateOp(OpCode, Args, Name, RetTy);
573 if (Error E = Result.takeError())
574 llvm_unreachable("Invalid arguments for operation");
575 return *Result;
576 }
577
getResRetType(Type * ElementTy)578 StructType *DXILOpBuilder::getResRetType(Type *ElementTy) {
579 return ::getResRetType(ElementTy);
580 }
581
getCBufRetType(Type * ElementTy)582 StructType *DXILOpBuilder::getCBufRetType(Type *ElementTy) {
583 return ::getCBufRetType(ElementTy);
584 }
585
getHandleType()586 StructType *DXILOpBuilder::getHandleType() {
587 return ::getHandleType(IRB.getContext());
588 }
589
getResBind(uint32_t LowerBound,uint32_t UpperBound,uint32_t SpaceID,dxil::ResourceClass RC)590 Constant *DXILOpBuilder::getResBind(uint32_t LowerBound, uint32_t UpperBound,
591 uint32_t SpaceID, dxil::ResourceClass RC) {
592 Type *Int32Ty = IRB.getInt32Ty();
593 Type *Int8Ty = IRB.getInt8Ty();
594 return ConstantStruct::get(
595 getResBindType(IRB.getContext()),
596 {ConstantInt::get(Int32Ty, LowerBound),
597 ConstantInt::get(Int32Ty, UpperBound),
598 ConstantInt::get(Int32Ty, SpaceID),
599 ConstantInt::get(Int8Ty, llvm::to_underlying(RC))});
600 }
601
getResProps(uint32_t Word0,uint32_t Word1)602 Constant *DXILOpBuilder::getResProps(uint32_t Word0, uint32_t Word1) {
603 Type *Int32Ty = IRB.getInt32Ty();
604 return ConstantStruct::get(
605 getResPropsType(IRB.getContext()),
606 {ConstantInt::get(Int32Ty, Word0), ConstantInt::get(Int32Ty, Word1)});
607 }
608
getOpCodeName(dxil::OpCode DXILOp)609 const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {
610 return ::getOpCodeName(DXILOp);
611 }
612 } // namespace dxil
613 } // namespace llvm
614