xref: /freebsd/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass implements regularization of LLVM IR for SPIR-V. The prototype of
10 // the pass was taken from SPIRV-LLVM translator.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRV.h"
15 #include "llvm/Demangle/Demangle.h"
16 #include "llvm/IR/InstIterator.h"
17 #include "llvm/IR/InstVisitor.h"
18 #include "llvm/IR/PassManager.h"
19 #include "llvm/Transforms/Utils/Cloning.h"
20 
21 #include <list>
22 
23 #define DEBUG_TYPE "spirv-regularizer"
24 
25 using namespace llvm;
26 
27 namespace {
28 struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
29   DenseMap<Function *, Function *> Old2NewFuncs;
30 
31 public:
32   static char ID;
SPIRVRegularizer__anonbaf691690111::SPIRVRegularizer33   SPIRVRegularizer() : FunctionPass(ID) {}
34   bool runOnFunction(Function &F) override;
getPassName__anonbaf691690111::SPIRVRegularizer35   StringRef getPassName() const override { return "SPIR-V Regularizer"; }
36 
getAnalysisUsage__anonbaf691690111::SPIRVRegularizer37   void getAnalysisUsage(AnalysisUsage &AU) const override {
38     FunctionPass::getAnalysisUsage(AU);
39   }
40   void visitCallInst(CallInst &CI);
41 
42 private:
43   void visitCallScalToVec(CallInst *CI, StringRef MangledName,
44                           StringRef DemangledName);
45   void runLowerConstExpr(Function &F);
46 };
47 } // namespace
48 
// Definition of the static pass identifier declared in SPIRVRegularizer.
char SPIRVRegularizer::ID = 0;

// Registers the pass under the DEBUG_TYPE argument ("spirv-regularizer") with
// the human-readable name "SPIR-V Regularizer".
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
                false)
53 
54 // Since SPIR-V cannot represent constant expression, constant expressions
55 // in LLVM IR need to be lowered to instructions. For each function,
56 // the constant expressions used by instructions of the function are replaced
57 // by instructions placed in the entry block since it dominates all other BBs.
58 // Each constant expression only needs to be lowered once in each function
59 // and all uses of it by instructions in that function are replaced by
60 // one instruction.
61 // TODO: remove redundant instructions for common subexpression.
void SPIRVRegularizer::runLowerConstExpr(Function &F) {
  LLVMContext &Ctx = F.getContext();
  // Seed the work list with every instruction currently in the function.
  // Instructions created while lowering are pushed to the front so that their
  // own operands are examined as well.
  std::list<Instruction *> WorkList;
  for (auto &II : instructions(F))
    WorkList.push_back(&II);

  // Iterator to the first basic block of F; lowered instructions are placed
  // there.
  auto FBegin = F.begin();
  while (!WorkList.empty()) {
    Instruction *II = WorkList.front();

    // Turns the constant expression V into an instruction and rewrites the
    // instruction users of V that live in F to use it. Functions are returned
    // unchanged; any other value is required to be a ConstantExpr (enforced by
    // the cast below).
    auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
      if (isa<Function>(V))
        return V;
      auto *CE = cast<ConstantExpr>(V);
      LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
      auto ReplInst = CE->getAsInstruction();
      // Insert right before II when II is in the first block, otherwise before
      // the last instruction of the first block.
      auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
      ReplInst->insertBefore(InsPoint->getIterator());
      LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
      std::vector<Instruction *> Users;
      // Do not replace use during iteration of use. Do it in another loop.
      for (auto U : CE->users()) {
        LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
        auto InstUser = dyn_cast<Instruction>(U);
        // Only replace users in scope of current function.
        if (InstUser && InstUser->getParent()->getParent() == &F)
          Users.push_back(InstUser);
      }
      for (auto &User : Users) {
        // If a user in the same block precedes the new instruction, hoist the
        // new instruction above that user so the definition comes first.
        if (ReplInst->getParent() == User->getParent() &&
            User->comesBefore(ReplInst))
          ReplInst->moveBefore(User->getIterator());
        User->replaceUsesOfWith(CE, ReplInst);
      }
      return ReplInst;
    };

    // II already holds the front element, so it can be dropped from the list
    // before new work is pushed to the front below.
    WorkList.pop_front();
    // Lowers a constant vector whose operands are all constant expressions or
    // functions. NumOfOp is the index of the operand of II being processed.
    // Returns the last insertelement instruction of the rebuilt vector, or
    // nullptr when the vector contains any other kind of operand.
    auto LowerConstantVec = [&II, &LowerOp, &WorkList,
                             &Ctx](ConstantVector *Vec,
                                   unsigned NumOfOp) -> Value * {
      if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
            return isa<ConstantExpr>(V) || isa<Function>(V);
          })) {
        // Expand a vector of constexprs and construct it back with
        // series of insertelement instructions.
        std::list<Value *> OpList;
        std::transform(Vec->op_begin(), Vec->op_end(),
                       std::back_inserter(OpList),
                       [LowerOp](Value *V) { return LowerOp(V); });
        Value *Repl = nullptr;
        unsigned Idx = 0;
        // For a PHI node the insertelement chain is placed before the last
        // instruction of the incoming block selected by the operand index;
        // for any other instruction it is placed directly before II.
        auto *PhiII = dyn_cast<PHINode>(II);
        Instruction *InsPoint =
            PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
        std::list<Instruction *> ReplList;
        for (auto V : OpList) {
          // Remember lowered elements that became instructions (functions are
          // passed through by LowerOp and are skipped here).
          if (auto *Inst = dyn_cast<Instruction>(V))
            ReplList.push_back(Inst);
          // The chain starts from a poison vector and inserts one element per
          // iteration at index Idx.
          Repl = InsertElementInst::Create(
              (Repl ? Repl : PoisonValue::get(Vec->getType())), V,
              ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "",
              InsPoint->getIterator());
        }
        // Queue the lowered element instructions at the front of the work list.
        WorkList.splice(WorkList.begin(), ReplList);
        return Repl;
      }
      return nullptr;
    };
    // Examine each operand of II and lower the constant vectors, constant
    // expressions and constants wrapped in metadata found there.
    for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
      auto *Op = II->getOperand(OI);
      if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
        Value *ReplInst = LowerConstantVec(Vec, OI);
        if (ReplInst)
          II->replaceUsesOfWith(Op, ReplInst);
      } else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
        // The new instruction may itself use constant expressions, so it is
        // queued for processing.
        WorkList.push_front(cast<Instruction>(LowerOp(CE)));
      } else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
        // Only metadata operands that wrap a constant are considered.
        auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
        if (!ConstMD)
          continue;
        Constant *C = ConstMD->getValue();
        Value *ReplInst = nullptr;
        if (auto *Vec = dyn_cast<ConstantVector>(C))
          ReplInst = LowerConstantVec(Vec, OI);
        if (auto *CE = dyn_cast<ConstantExpr>(C))
          ReplInst = LowerOp(CE);
        if (!ReplInst)
          continue;
        // Re-wrap the lowered value as metadata and install it as the operand.
        Metadata *RepMD = ValueAsMetadata::get(ReplInst);
        Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
        II->setOperand(OI, RepMDVal);
        WorkList.push_front(cast<Instruction>(ReplInst));
      }
    }
  }
}
159 
160 // It fixes calls to OCL builtins that accept vector arguments and one of them
161 // is actually a scalar splat.
visitCallInst(CallInst & CI)162 void SPIRVRegularizer::visitCallInst(CallInst &CI) {
163   auto F = CI.getCalledFunction();
164   if (!F)
165     return;
166 
167   auto MangledName = F->getName();
168   char *NameStr = itaniumDemangle(F->getName().data());
169   if (!NameStr)
170     return;
171   StringRef DemangledName(NameStr);
172 
173   // TODO: add support for other builtins.
174   if (DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
175       DemangledName.starts_with("min") || DemangledName.starts_with("max"))
176     visitCallScalToVec(&CI, MangledName, DemangledName);
177   free(NameStr);
178 }
179 
visitCallScalToVec(CallInst * CI,StringRef MangledName,StringRef DemangledName)180 void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
181                                           StringRef DemangledName) {
182   // Check if all arguments have the same type - it's simple case.
183   auto Uniform = true;
184   Type *Arg0Ty = CI->getOperand(0)->getType();
185   auto IsArg0Vector = isa<VectorType>(Arg0Ty);
186   for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
187     Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
188   if (Uniform)
189     return;
190 
191   auto *OldF = CI->getCalledFunction();
192   Function *NewF = nullptr;
193   auto [It, Inserted] = Old2NewFuncs.try_emplace(OldF);
194   if (Inserted) {
195     AttributeList Attrs = CI->getCalledFunction()->getAttributes();
196     SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
197     auto *NewFTy =
198         FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
199     NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
200                             *OldF->getParent());
201     ValueToValueMapTy VMap;
202     auto NewFArgIt = NewF->arg_begin();
203     for (auto &Arg : OldF->args()) {
204       auto ArgName = Arg.getName();
205       NewFArgIt->setName(ArgName);
206       VMap[&Arg] = &(*NewFArgIt++);
207     }
208     SmallVector<ReturnInst *, 8> Returns;
209     CloneFunctionInto(NewF, OldF, VMap,
210                       CloneFunctionChangeType::LocalChangesOnly, Returns);
211     NewF->setAttributes(Attrs);
212     It->second = NewF;
213   } else {
214     NewF = It->second;
215   }
216   assert(NewF);
217 
218   // This produces an instruction sequence that implements a splat of
219   // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
220   // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
221   // For instance (transcoding/OpMin.ll), this call
222   //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
223   // is translated to
224   //    %8 = OpUndef %v2uint
225   //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
226   //   ...
227   //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
228   //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
229   // %call = OpExtInst %v2uint %1 s_min %14 %11
230   auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
231   PoisonValue *PVal = PoisonValue::get(Arg0Ty);
232   Instruction *Inst = InsertElementInst::Create(
233       PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
234   ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
235   Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
236   Value *NewVec =
237       new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
238   CI->setOperand(1, NewVec);
239   CI->replaceUsesOfWith(OldF, NewF);
240   CI->mutateFunctionType(NewF->getFunctionType());
241 }
242 
runOnFunction(Function & F)243 bool SPIRVRegularizer::runOnFunction(Function &F) {
244   runLowerConstExpr(F);
245   visit(F);
246   for (auto &OldNew : Old2NewFuncs) {
247     Function *OldF = OldNew.first;
248     Function *NewF = OldNew.second;
249     NewF->takeName(OldF);
250     OldF->eraseFromParent();
251   }
252   return true;
253 }
254 
createSPIRVRegularizerPass()255 FunctionPass *llvm::createSPIRVRegularizerPass() {
256   return new SPIRVRegularizer();
257 }
258