xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g., libmvec, SVML) using the TargetLibraryInfo
// interface.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/CodeGen/ReplaceWithVeclib.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/DemandedBits.h"
20 #include "llvm/Analysis/GlobalsModRef.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/Analysis/VectorUtils.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/InstIterator.h"
28 #include "llvm/Support/TypeSize.h"
29 #include "llvm/Transforms/Utils/ModuleUtils.h"
30 
31 using namespace llvm;
32 
33 #define DEBUG_TYPE "replace-with-veclib"
34 
35 STATISTIC(NumCallsReplaced,
36           "Number of calls to intrinsics that have been replaced.");
37 
38 STATISTIC(NumTLIFuncDeclAdded,
39           "Number of vector library function declarations added.");
40 
41 STATISTIC(NumFuncUsedAdded,
42           "Number of functions added to `llvm.compiler.used`");
43 
44 /// Returns a vector Function that it adds to the Module \p M. When an \p
45 /// ScalarFunc is not null, it copies its attributes to the newly created
46 /// Function.
47 Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
48                          const StringRef TLIName,
49                          Function *ScalarFunc = nullptr) {
50   Function *TLIFunc = M->getFunction(TLIName);
51   if (!TLIFunc) {
52     TLIFunc =
53         Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
54     if (ScalarFunc)
55       TLIFunc->copyAttributesFrom(ScalarFunc);
56 
57     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
58                       << TLIName << "` of type `" << *(TLIFunc->getType())
59                       << "` to module.\n");
60 
61     ++NumTLIFuncDeclAdded;
62     // Add the freshly created function to llvm.compiler.used, similar to as it
63     // is done in InjectTLIMappings.
64     appendToCompilerUsed(*M, {TLIFunc});
65     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
66                       << "` to `@llvm.compiler.used`.\n");
67     ++NumFuncUsedAdded;
68   }
69   return TLIFunc;
70 }
71 
72 /// Replace the instruction \p I with a call to the corresponding function from
73 /// the vector library (\p TLIVecFunc).
74 static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
75                                    Function *TLIVecFunc) {
76   IRBuilder<> IRBuilder(&I);
77   auto *CI = dyn_cast<CallInst>(&I);
78   SmallVector<Value *> Args(CI ? CI->args() : I.operands());
79   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
80     auto *MaskTy =
81         VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
82     Args.insert(Args.begin() + OptMaskpos.value(),
83                 Constant::getAllOnesValue(MaskTy));
84   }
85 
86   // If it is a call instruction, preserve the operand bundles.
87   SmallVector<OperandBundleDef, 1> OpBundles;
88   if (CI)
89     CI->getOperandBundlesAsDefs(OpBundles);
90 
91   auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
92   I.replaceAllUsesWith(Replacement);
93   // Preserve fast math flags for FP math.
94   if (isa<FPMathOperator>(Replacement))
95     Replacement->copyFastMathFlags(&I);
96 }
97 
98 /// Returns true when successfully replaced \p I with a suitable function taking
99 /// vector arguments, based on available mappings in the \p TLI. Currently only
100 /// works when \p I is a call to vectorized intrinsic or the frem instruction.
101 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
102                                     Instruction &I) {
103   // At the moment VFABI assumes the return type is always widened unless it is
104   // a void type.
105   auto *VTy = dyn_cast<VectorType>(I.getType());
106   ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
107 
108   // Compute the argument types of the corresponding scalar call and the scalar
109   // function name. For calls, it additionally finds the function to replace
110   // and checks that all vector operands match the previously found EC.
111   SmallVector<Type *, 8> ScalarArgTypes;
112   std::string ScalarName;
113   Function *FuncToReplace = nullptr;
114   if (auto *CI = dyn_cast<CallInst>(&I)) {
115     FuncToReplace = CI->getCalledFunction();
116     Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
117     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
118     for (auto Arg : enumerate(CI->args())) {
119       auto *ArgTy = Arg.value()->getType();
120       if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
121         ScalarArgTypes.push_back(ArgTy);
122       } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
123         ScalarArgTypes.push_back(VectorArgTy->getElementType());
124         // When return type is void, set EC to the first vector argument, and
125         // disallow vector arguments with different ECs.
126         if (EC.isZero())
127           EC = VectorArgTy->getElementCount();
128         else if (EC != VectorArgTy->getElementCount())
129           return false;
130       } else
131         // Exit when it is supposed to be a vector argument but it isn't.
132         return false;
133     }
134     // Try to reconstruct the name for the scalar version of the instruction,
135     // using scalar argument types.
136     ScalarName = Intrinsic::isOverloaded(IID)
137                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
138                      : Intrinsic::getName(IID).str();
139   } else {
140     assert(VTy && "Return type must be a vector");
141     auto *ScalarTy = VTy->getScalarType();
142     LibFunc Func;
143     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
144       return false;
145     ScalarName = TLI.getName(Func);
146     ScalarArgTypes = {ScalarTy, ScalarTy};
147   }
148 
149   // Try to find the mapping for the scalar version of this intrinsic and the
150   // exact vector width of the call operands in the TargetLibraryInfo. First,
151   // check with a non-masked variant, and if that fails try with a masked one.
152   const VecDesc *VD =
153       TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
154   if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
155     return false;
156 
157   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
158                     << "` and vector width " << EC << " to: `"
159                     << VD->getVectorFnName() << "`.\n");
160 
161   // Replace the call to the intrinsic with a call to the vector library
162   // function.
163   Type *ScalarRetTy = I.getType()->getScalarType();
164   FunctionType *ScalarFTy =
165       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
166   const std::string MangledName = VD->getVectorFunctionABIVariantString();
167   auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
168   if (!OptInfo)
169     return false;
170 
171   FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
172   if (!VectorFTy)
173     return false;
174 
175   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
176                                      VD->getVectorFnName(), FuncToReplace);
177   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
178   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
179                     << "` with call to `" << TLIFunc->getName() << "`.\n");
180   ++NumCallsReplaced;
181   return true;
182 }
183 
184 /// Supported instruction \p I must be a vectorized frem or a call to an
185 /// intrinsic that returns either void or a vector.
186 static bool isSupportedInstruction(Instruction *I) {
187   Type *Ty = I->getType();
188   if (auto *CI = dyn_cast<CallInst>(I))
189     return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
190            CI->getCalledFunction()->getIntrinsicID() !=
191                Intrinsic::not_intrinsic;
192   if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
193     return true;
194   return false;
195 }
196 
197 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
198   bool Changed = false;
199   SmallVector<Instruction *> ReplacedCalls;
200   for (auto &I : instructions(F)) {
201     if (!isSupportedInstruction(&I))
202       continue;
203     if (replaceWithCallToVeclib(TLI, I)) {
204       ReplacedCalls.push_back(&I);
205       Changed = true;
206     }
207   }
208   // Erase the calls to the intrinsics that have been replaced
209   // with calls to the vector library.
210   for (auto *CI : ReplacedCalls)
211     CI->eraseFromParent();
212   return Changed;
213 }
214 
215 ////////////////////////////////////////////////////////////////////////////////
216 // New pass manager implementation.
217 ////////////////////////////////////////////////////////////////////////////////
218 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
219                                          FunctionAnalysisManager &AM) {
220   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
221   auto Changed = runImpl(TLI, F);
222   if (Changed) {
223     PreservedAnalyses PA;
224     PA.preserveSet<CFGAnalyses>();
225     PA.preserve<TargetLibraryAnalysis>();
226     PA.preserve<ScalarEvolutionAnalysis>();
227     PA.preserve<LoopAccessAnalysis>();
228     PA.preserve<DemandedBitsAnalysis>();
229     PA.preserve<OptimizationRemarkEmitterAnalysis>();
230     return PA;
231   }
232 
233   // The pass did not replace any calls, hence it preserves all analyses.
234   return PreservedAnalyses::all();
235 }
236 
237 ////////////////////////////////////////////////////////////////////////////////
238 // Legacy PM Implementation.
239 ////////////////////////////////////////////////////////////////////////////////
240 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
241   const TargetLibraryInfo &TLI =
242       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
243   return runImpl(TLI, F);
244 }
245 
246 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
247   AU.setPreservesCFG();
248   AU.addRequired<TargetLibraryInfoWrapperPass>();
249   AU.addPreserved<TargetLibraryInfoWrapperPass>();
250   AU.addPreserved<ScalarEvolutionWrapperPass>();
251   AU.addPreserved<AAResultsWrapperPass>();
252   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
253   AU.addPreserved<GlobalsAAWrapperPass>();
254 }
255 
256 ////////////////////////////////////////////////////////////////////////////////
257 // Legacy Pass manager initialization
258 ////////////////////////////////////////////////////////////////////////////////
259 char ReplaceWithVeclibLegacy::ID = 0;
260 
261 INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
262                       "Replace intrinsics with calls to vector library", false,
263                       false)
264 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
265 INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
266                     "Replace intrinsics with calls to vector library", false,
267                     false)
268 
269 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
270   return new ReplaceWithVeclibLegacy();
271 }
272