xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision 8311bc5f17dec348749f763b82dfe2737bc53cd7)
1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
10 // with vector operands) with matching calls to functions from a vector
11 // library (e.g., libmvec, SVML) according to TargetLibraryInfo.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/CodeGen/ReplaceWithVeclib.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/DemandedBits.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/TargetLibraryInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/Transforms/Utils/ModuleUtils.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "replace-with-veclib"
31 
32 STATISTIC(NumCallsReplaced,
33           "Number of calls to intrinsics that have been replaced.");
34 
35 STATISTIC(NumTLIFuncDeclAdded,
36           "Number of vector library function declarations added.");
37 
38 STATISTIC(NumFuncUsedAdded,
39           "Number of functions added to `llvm.compiler.used`");
40 
41 static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
42   Module *M = CI.getModule();
43 
44   Function *OldFunc = CI.getCalledFunction();
45 
46   // Check if the vector library function is already declared in this module,
47   // otherwise insert it.
48   Function *TLIFunc = M->getFunction(TLIName);
49   if (!TLIFunc) {
50     TLIFunc = Function::Create(OldFunc->getFunctionType(),
51                                Function::ExternalLinkage, TLIName, *M);
52     TLIFunc->copyAttributesFrom(OldFunc);
53 
54     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
55                       << TLIName << "` of type `" << *(TLIFunc->getType())
56                       << "` to module.\n");
57 
58     ++NumTLIFuncDeclAdded;
59 
60     // Add the freshly created function to llvm.compiler.used,
61     // similar to as it is done in InjectTLIMappings
62     appendToCompilerUsed(*M, {TLIFunc});
63 
64     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
65                       << "` to `@llvm.compiler.used`.\n");
66     ++NumFuncUsedAdded;
67   }
68 
69   // Replace the call to the vector intrinsic with a call
70   // to the corresponding function from the vector library.
71   IRBuilder<> IRBuilder(&CI);
72   SmallVector<Value *> Args(CI.args());
73   // Preserve the operand bundles.
74   SmallVector<OperandBundleDef, 1> OpBundles;
75   CI.getOperandBundlesAsDefs(OpBundles);
76   CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
77   assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
78          "Expecting function types to be identical");
79   CI.replaceAllUsesWith(Replacement);
80   if (isa<FPMathOperator>(Replacement)) {
81     // Preserve fast math flags for FP math.
82     Replacement->copyFastMathFlags(&CI);
83   }
84 
85   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
86                     << OldFunc->getName() << "` with call to `" << TLIName
87                     << "`.\n");
88   ++NumCallsReplaced;
89   return true;
90 }
91 
92 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
93                                     CallInst &CI) {
94   if (!CI.getCalledFunction()) {
95     return false;
96   }
97 
98   auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
99   if (IntrinsicID == Intrinsic::not_intrinsic) {
100     // Replacement is only performed for intrinsic functions
101     return false;
102   }
103 
104   // Convert vector arguments to scalar type and check that
105   // all vector operands have identical vector width.
106   ElementCount VF = ElementCount::getFixed(0);
107   SmallVector<Type *> ScalarTypes;
108   for (auto Arg : enumerate(CI.args())) {
109     auto *ArgType = Arg.value()->getType();
110     // Vector calls to intrinsics can still have
111     // scalar operands for specific arguments.
112     if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
113       ScalarTypes.push_back(ArgType);
114     } else {
115       // The argument in this place should be a vector if
116       // this is a call to a vector intrinsic.
117       auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
118       if (!VectorArgTy) {
119         // The argument is not a vector, do not perform
120         // the replacement.
121         return false;
122       }
123       ElementCount NumElements = VectorArgTy->getElementCount();
124       if (NumElements.isScalable()) {
125         // The current implementation does not support
126         // scalable vectors.
127         return false;
128       }
129       if (VF.isNonZero() && VF != NumElements) {
130         // The different arguments differ in vector size.
131         return false;
132       } else {
133         VF = NumElements;
134       }
135       ScalarTypes.push_back(VectorArgTy->getElementType());
136     }
137   }
138 
139   // Try to reconstruct the name for the scalar version of this
140   // intrinsic using the intrinsic ID and the argument types
141   // converted to scalar above.
142   std::string ScalarName;
143   if (Intrinsic::isOverloaded(IntrinsicID)) {
144     ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
145   } else {
146     ScalarName = Intrinsic::getName(IntrinsicID).str();
147   }
148 
149   if (!TLI.isFunctionVectorizable(ScalarName)) {
150     // The TargetLibraryInfo does not contain a vectorized version of
151     // the scalar function.
152     return false;
153   }
154 
155   // Try to find the mapping for the scalar version of this intrinsic
156   // and the exact vector width of the call operands in the
157   // TargetLibraryInfo.
158   const std::string TLIName =
159       std::string(TLI.getVectorizedFunction(ScalarName, VF));
160 
161   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
162                     << ScalarName << "` and vector width " << VF << ".\n");
163 
164   if (!TLIName.empty()) {
165     // Found the correct mapping in the TargetLibraryInfo,
166     // replace the call to the intrinsic with a call to
167     // the vector library function.
168     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
169                       << "`.\n");
170     return replaceWithTLIFunction(CI, TLIName);
171   }
172 
173   return false;
174 }
175 
176 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
177   bool Changed = false;
178   SmallVector<CallInst *> ReplacedCalls;
179   for (auto &I : instructions(F)) {
180     if (auto *CI = dyn_cast<CallInst>(&I)) {
181       if (replaceWithCallToVeclib(TLI, *CI)) {
182         ReplacedCalls.push_back(CI);
183         Changed = true;
184       }
185     }
186   }
187   // Erase the calls to the intrinsics that have been replaced
188   // with calls to the vector library.
189   for (auto *CI : ReplacedCalls) {
190     CI->eraseFromParent();
191   }
192   return Changed;
193 }
194 
195 ////////////////////////////////////////////////////////////////////////////////
196 // New pass manager implementation.
197 ////////////////////////////////////////////////////////////////////////////////
198 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
199                                          FunctionAnalysisManager &AM) {
200   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
201   auto Changed = runImpl(TLI, F);
202   if (Changed) {
203     PreservedAnalyses PA;
204     PA.preserveSet<CFGAnalyses>();
205     PA.preserve<TargetLibraryAnalysis>();
206     PA.preserve<ScalarEvolutionAnalysis>();
207     PA.preserve<LoopAccessAnalysis>();
208     PA.preserve<DemandedBitsAnalysis>();
209     PA.preserve<OptimizationRemarkEmitterAnalysis>();
210     return PA;
211   } else {
212     // The pass did not replace any calls, hence it preserves all analyses.
213     return PreservedAnalyses::all();
214   }
215 }
216 
217 ////////////////////////////////////////////////////////////////////////////////
218 // Legacy PM Implementation.
219 ////////////////////////////////////////////////////////////////////////////////
220 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
221   const TargetLibraryInfo &TLI =
222       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
223   return runImpl(TLI, F);
224 }
225 
226 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
227   AU.setPreservesCFG();
228   AU.addRequired<TargetLibraryInfoWrapperPass>();
229   AU.addPreserved<TargetLibraryInfoWrapperPass>();
230   AU.addPreserved<ScalarEvolutionWrapperPass>();
231   AU.addPreserved<AAResultsWrapperPass>();
232   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
233   AU.addPreserved<GlobalsAAWrapperPass>();
234 }
235 
236 ////////////////////////////////////////////////////////////////////////////////
237 // Legacy Pass manager initialization
238 ////////////////////////////////////////////////////////////////////////////////
// Definition of the legacy pass's static ID member.
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under the DEBUG_TYPE name ("replace-with-veclib")
// with the description string below, and declare its dependency on
// TargetLibraryInfoWrapperPass, which runOnFunction() queries.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)
248 
/// Factory for the legacy pass manager version of the pass.
///
/// \returns a newly heap-allocated ReplaceWithVeclibLegacy instance.
FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  return new ReplaceWithVeclibLegacy();
}
252