xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision 271171e0d97b88ba2a7c3bf750c9672b484c1c13)
1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls ===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Replaces calls to LLVM vector intrinsics (i.e., calls to LLVM intrinsics
10 // with vector operands) with matching calls to functions from a vector
11 // library (e.g., libmvec, SVML) according to TargetLibraryInfo.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/CodeGen/ReplaceWithVeclib.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/DemandedBits.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/TargetLibraryInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/IRBuilder.h"
25 #include "llvm/IR/InstIterator.h"
26 #include "llvm/IR/IntrinsicInst.h"
27 #include "llvm/Transforms/Utils/ModuleUtils.h"
28 
29 using namespace llvm;
30 
// Name of this pass. It is used as the prefix of every LLVM_DEBUG message in
// this file and as the pass argument passed to INITIALIZE_PASS_BEGIN/END.
#define DEBUG_TYPE "replace-with-veclib"

// Incremented once per intrinsic call rewritten in replaceWithTLIFunction.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented when replaceWithTLIFunction has to create a new declaration for
// the vector library function because the module did not already have one.
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented each time a freshly created declaration is appended to
// `llvm.compiler.used` (via appendToCompilerUsed).
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");
41 
42 static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
43   Module *M = CI.getModule();
44 
45   Function *OldFunc = CI.getCalledFunction();
46 
47   // Check if the vector library function is already declared in this module,
48   // otherwise insert it.
49   Function *TLIFunc = M->getFunction(TLIName);
50   if (!TLIFunc) {
51     TLIFunc = Function::Create(OldFunc->getFunctionType(),
52                                Function::ExternalLinkage, TLIName, *M);
53     TLIFunc->copyAttributesFrom(OldFunc);
54 
55     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
56                       << TLIName << "` of type `" << *(TLIFunc->getType())
57                       << "` to module.\n");
58 
59     ++NumTLIFuncDeclAdded;
60 
61     // Add the freshly created function to llvm.compiler.used,
62     // similar to as it is done in InjectTLIMappings
63     appendToCompilerUsed(*M, {TLIFunc});
64 
65     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
66                       << "` to `@llvm.compiler.used`.\n");
67     ++NumFuncUsedAdded;
68   }
69 
70   // Replace the call to the vector intrinsic with a call
71   // to the corresponding function from the vector library.
72   IRBuilder<> IRBuilder(&CI);
73   SmallVector<Value *> Args(CI.args());
74   // Preserve the operand bundles.
75   SmallVector<OperandBundleDef, 1> OpBundles;
76   CI.getOperandBundlesAsDefs(OpBundles);
77   CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
78   assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
79          "Expecting function types to be identical");
80   CI.replaceAllUsesWith(Replacement);
81   if (isa<FPMathOperator>(Replacement)) {
82     // Preserve fast math flags for FP math.
83     Replacement->copyFastMathFlags(&CI);
84   }
85 
86   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
87                     << OldFunc->getName() << "` with call to `" << TLIName
88                     << "`.\n");
89   ++NumCallsReplaced;
90   return true;
91 }
92 
93 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
94                                     CallInst &CI) {
95   if (!CI.getCalledFunction()) {
96     return false;
97   }
98 
99   auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
100   if (IntrinsicID == Intrinsic::not_intrinsic) {
101     // Replacement is only performed for intrinsic functions
102     return false;
103   }
104 
105   // Convert vector arguments to scalar type and check that
106   // all vector operands have identical vector width.
107   ElementCount VF = ElementCount::getFixed(0);
108   SmallVector<Type *> ScalarTypes;
109   for (auto Arg : enumerate(CI.args())) {
110     auto *ArgType = Arg.value()->getType();
111     // Vector calls to intrinsics can still have
112     // scalar operands for specific arguments.
113     if (hasVectorInstrinsicScalarOpd(IntrinsicID, Arg.index())) {
114       ScalarTypes.push_back(ArgType);
115     } else {
116       // The argument in this place should be a vector if
117       // this is a call to a vector intrinsic.
118       auto *VectorArgTy = dyn_cast<VectorType>(ArgType);
119       if (!VectorArgTy) {
120         // The argument is not a vector, do not perform
121         // the replacement.
122         return false;
123       }
124       ElementCount NumElements = VectorArgTy->getElementCount();
125       if (NumElements.isScalable()) {
126         // The current implementation does not support
127         // scalable vectors.
128         return false;
129       }
130       if (VF.isNonZero() && VF != NumElements) {
131         // The different arguments differ in vector size.
132         return false;
133       } else {
134         VF = NumElements;
135       }
136       ScalarTypes.push_back(VectorArgTy->getElementType());
137     }
138   }
139 
140   // Try to reconstruct the name for the scalar version of this
141   // intrinsic using the intrinsic ID and the argument types
142   // converted to scalar above.
143   std::string ScalarName;
144   if (Intrinsic::isOverloaded(IntrinsicID)) {
145     ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
146   } else {
147     ScalarName = Intrinsic::getName(IntrinsicID).str();
148   }
149 
150   if (!TLI.isFunctionVectorizable(ScalarName)) {
151     // The TargetLibraryInfo does not contain a vectorized version of
152     // the scalar function.
153     return false;
154   }
155 
156   // Try to find the mapping for the scalar version of this intrinsic
157   // and the exact vector width of the call operands in the
158   // TargetLibraryInfo.
159   const std::string TLIName =
160       std::string(TLI.getVectorizedFunction(ScalarName, VF));
161 
162   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
163                     << ScalarName << "` and vector width " << VF << ".\n");
164 
165   if (!TLIName.empty()) {
166     // Found the correct mapping in the TargetLibraryInfo,
167     // replace the call to the intrinsic with a call to
168     // the vector library function.
169     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
170                       << "`.\n");
171     return replaceWithTLIFunction(CI, TLIName);
172   }
173 
174   return false;
175 }
176 
177 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
178   bool Changed = false;
179   SmallVector<CallInst *> ReplacedCalls;
180   for (auto &I : instructions(F)) {
181     if (auto *CI = dyn_cast<CallInst>(&I)) {
182       if (replaceWithCallToVeclib(TLI, *CI)) {
183         ReplacedCalls.push_back(CI);
184         Changed = true;
185       }
186     }
187   }
188   // Erase the calls to the intrinsics that have been replaced
189   // with calls to the vector library.
190   for (auto *CI : ReplacedCalls) {
191     CI->eraseFromParent();
192   }
193   return Changed;
194 }
195 
196 ////////////////////////////////////////////////////////////////////////////////
197 // New pass manager implementation.
198 ////////////////////////////////////////////////////////////////////////////////
199 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
200                                          FunctionAnalysisManager &AM) {
201   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
202   auto Changed = runImpl(TLI, F);
203   if (Changed) {
204     PreservedAnalyses PA;
205     PA.preserveSet<CFGAnalyses>();
206     PA.preserve<TargetLibraryAnalysis>();
207     PA.preserve<ScalarEvolutionAnalysis>();
208     PA.preserve<LoopAccessAnalysis>();
209     PA.preserve<DemandedBitsAnalysis>();
210     PA.preserve<OptimizationRemarkEmitterAnalysis>();
211     return PA;
212   } else {
213     // The pass did not replace any calls, hence it preserves all analyses.
214     return PreservedAnalyses::all();
215   }
216 }
217 
218 ////////////////////////////////////////////////////////////////////////////////
219 // Legacy PM Implementation.
220 ////////////////////////////////////////////////////////////////////////////////
221 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
222   const TargetLibraryInfo &TLI =
223       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
224   return runImpl(TLI, F);
225 }
226 
227 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
228   AU.setPreservesCFG();
229   AU.addRequired<TargetLibraryInfoWrapperPass>();
230   AU.addPreserved<TargetLibraryInfoWrapperPass>();
231   AU.addPreserved<ScalarEvolutionWrapperPass>();
232   AU.addPreserved<AAResultsWrapperPass>();
233   AU.addPreserved<LoopAccessLegacyAnalysis>();
234   AU.addPreserved<DemandedBitsWrapperPass>();
235   AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
236   AU.addPreserved<GlobalsAAWrapperPass>();
237 }
238 
239 ////////////////////////////////////////////////////////////////////////////////
240 // Legacy Pass manager initialization
241 ////////////////////////////////////////////////////////////////////////////////
// Pass identifier for the legacy pass manager. Legacy passes are identified by
// the address of their ID member; the stored value itself carries no meaning.
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under the argument DEBUG_TYPE
// ("replace-with-veclib"). Also record its dependency on
// TargetLibraryInfoWrapperPass, which matches the addRequired in
// getAnalysisUsage. The two trailing `false` arguments are INITIALIZE_PASS's
// CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)
251 
252 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
253   return new ReplaceWithVeclibLegacy();
254 }
255