1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Replaces calls to LLVM intrinsics with matching calls to functions from a
// vector library (e.g. libmvec, SVML) using the TargetLibraryInfo interface.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/CodeGen/ReplaceWithVeclib.h"
15 #include "llvm/ADT/STLExtras.h"
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/DemandedBits.h"
19 #include "llvm/Analysis/GlobalsModRef.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/TargetLibraryInfo.h"
22 #include "llvm/Analysis/VectorUtils.h"
23 #include "llvm/CodeGen/Passes.h"
24 #include "llvm/IR/DerivedTypes.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstIterator.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/VFABIDemangler.h"
29 #include "llvm/Support/TypeSize.h"
30 #include "llvm/Transforms/Utils/ModuleUtils.h"
31
32 using namespace llvm;
33
34 #define DEBUG_TYPE "replace-with-veclib"
35
36 STATISTIC(NumCallsReplaced,
37 "Number of calls to intrinsics that have been replaced.");
38
39 STATISTIC(NumTLIFuncDeclAdded,
40 "Number of vector library function declarations added.");
41
42 STATISTIC(NumFuncUsedAdded,
43 "Number of functions added to `llvm.compiler.used`");
44
45 /// Returns a vector Function that it adds to the Module \p M. When an \p
46 /// ScalarFunc is not null, it copies its attributes to the newly created
47 /// Function.
getTLIFunction(Module * M,FunctionType * VectorFTy,const StringRef TLIName,Function * ScalarFunc=nullptr)48 Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49 const StringRef TLIName,
50 Function *ScalarFunc = nullptr) {
51 Function *TLIFunc = M->getFunction(TLIName);
52 if (!TLIFunc) {
53 TLIFunc =
54 Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
55 if (ScalarFunc)
56 TLIFunc->copyAttributesFrom(ScalarFunc);
57
58 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59 << TLIName << "` of type `" << *(TLIFunc->getType())
60 << "` to module.\n");
61
62 ++NumTLIFuncDeclAdded;
63 // Add the freshly created function to llvm.compiler.used, similar to as it
64 // is done in InjectTLIMappings.
65 appendToCompilerUsed(*M, {TLIFunc});
66 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67 << "` to `@llvm.compiler.used`.\n");
68 ++NumFuncUsedAdded;
69 }
70 return TLIFunc;
71 }
72
73 /// Replace the intrinsic call \p II to \p TLIVecFunc, which is the
74 /// corresponding function from the vector library.
replaceWithTLIFunction(IntrinsicInst * II,VFInfo & Info,Function * TLIVecFunc)75 static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info,
76 Function *TLIVecFunc) {
77 IRBuilder<> IRBuilder(II);
78 SmallVector<Value *> Args(II->args());
79 if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
80 auto *MaskTy =
81 VectorType::get(Type::getInt1Ty(II->getContext()), Info.Shape.VF);
82 Args.insert(Args.begin() + OptMaskpos.value(),
83 Constant::getAllOnesValue(MaskTy));
84 }
85
86 // Preserve the operand bundles.
87 SmallVector<OperandBundleDef, 1> OpBundles;
88 II->getOperandBundlesAsDefs(OpBundles);
89
90 auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
91 II->replaceAllUsesWith(Replacement);
92 // Preserve fast math flags for FP math.
93 if (isa<FPMathOperator>(Replacement))
94 Replacement->copyFastMathFlags(II);
95 }
96
97 /// Returns true when successfully replaced \p II, which is a call to a
98 /// vectorized intrinsic, with a suitable function taking vector arguments,
99 /// based on available mappings in the \p TLI.
replaceWithCallToVeclib(const TargetLibraryInfo & TLI,IntrinsicInst * II)100 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
101 IntrinsicInst *II) {
102 assert(II != nullptr && "Intrinsic cannot be null");
103 Intrinsic::ID IID = II->getIntrinsicID();
104 Type *RetTy = II->getType();
105 Type *ScalarRetTy = RetTy->getScalarType();
106 // At the moment VFABI assumes the return type is always widened unless it is
107 // a void type.
108 auto *VTy = dyn_cast<VectorType>(RetTy);
109 ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
110
111 // OloadTys collects types used in scalar intrinsic overload name.
112 SmallVector<Type *, 3> OloadTys;
113 if (!RetTy->isVoidTy() &&
114 isVectorIntrinsicWithOverloadTypeAtArg(IID, -1, /*TTI=*/nullptr))
115 OloadTys.push_back(ScalarRetTy);
116
117 // Compute the argument types of the corresponding scalar call and check that
118 // all vector operands match the previously found EC.
119 SmallVector<Type *, 8> ScalarArgTypes;
120 for (auto Arg : enumerate(II->args())) {
121 auto *ArgTy = Arg.value()->getType();
122 bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(IID, Arg.index(),
123 /*TTI=*/nullptr);
124 if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index(), /*TTI=*/nullptr)) {
125 ScalarArgTypes.push_back(ArgTy);
126 if (IsOloadTy)
127 OloadTys.push_back(ArgTy);
128 } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
129 auto *ScalarArgTy = VectorArgTy->getElementType();
130 ScalarArgTypes.push_back(ScalarArgTy);
131 if (IsOloadTy)
132 OloadTys.push_back(ScalarArgTy);
133 // When return type is void, set EC to the first vector argument, and
134 // disallow vector arguments with different ECs.
135 if (EC.isZero())
136 EC = VectorArgTy->getElementCount();
137 else if (EC != VectorArgTy->getElementCount())
138 return false;
139 } else
140 // Exit when it is supposed to be a vector argument but it isn't.
141 return false;
142 }
143
144 // Try to reconstruct the name for the scalar version of the instruction,
145 // using scalar argument types.
146 std::string ScalarName =
147 Intrinsic::isOverloaded(IID)
148 ? Intrinsic::getName(IID, OloadTys, II->getModule())
149 : Intrinsic::getName(IID).str();
150
151 // Try to find the mapping for the scalar version of this intrinsic and the
152 // exact vector width of the call operands in the TargetLibraryInfo. First,
153 // check with a non-masked variant, and if that fails try with a masked one.
154 const VecDesc *VD =
155 TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
156 if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
157 return false;
158
159 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
160 << "` and vector width " << EC << " to: `"
161 << VD->getVectorFnName() << "`.\n");
162
163 // Replace the call to the intrinsic with a call to the vector library
164 // function.
165 FunctionType *ScalarFTy =
166 FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
167 const std::string MangledName = VD->getVectorFunctionABIVariantString();
168 auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
169 if (!OptInfo)
170 return false;
171
172 // There is no guarantee that the vectorized instructions followed the VFABI
173 // specification when being created, this is why we need to add extra check to
174 // make sure that the operands of the vector function obtained via VFABI match
175 // the operands of the original vector instruction.
176 for (auto &VFParam : OptInfo->Shape.Parameters) {
177 if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
178 continue;
179
180 // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
181 // a bug in the VFABI parser.
182 assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range");
183 Type *OrigTy = II->getArgOperand(VFParam.ParamPos)->getType();
184 if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
185 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
186 << ". Wrong type at index " << VFParam.ParamPos << ": "
187 << *OrigTy << "\n");
188 return false;
189 }
190 }
191
192 FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
193 if (!VectorFTy)
194 return false;
195
196 Function *TLIFunc =
197 getTLIFunction(II->getModule(), VectorFTy, VD->getVectorFnName(),
198 II->getCalledFunction());
199 replaceWithTLIFunction(II, *OptInfo, TLIFunc);
200 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
201 << "` with call to `" << TLIFunc->getName() << "`.\n");
202 ++NumCallsReplaced;
203 return true;
204 }
205
runImpl(const TargetLibraryInfo & TLI,Function & F)206 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
207 SmallVector<Instruction *> ReplacedCalls;
208 for (auto &I : instructions(F)) {
209 // Process only intrinsic calls that return void or a vector.
210 if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
211 if (II->getIntrinsicID() == Intrinsic::not_intrinsic)
212 continue;
213 if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy())
214 continue;
215
216 if (replaceWithCallToVeclib(TLI, II))
217 ReplacedCalls.push_back(&I);
218 }
219 }
220 // Erase any intrinsic calls that were replaced with vector library calls.
221 for (auto *I : ReplacedCalls)
222 I->eraseFromParent();
223 return !ReplacedCalls.empty();
224 }
225
226 ////////////////////////////////////////////////////////////////////////////////
227 // New pass manager implementation.
228 ////////////////////////////////////////////////////////////////////////////////
run(Function & F,FunctionAnalysisManager & AM)229 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
230 FunctionAnalysisManager &AM) {
231 const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
232 auto Changed = runImpl(TLI, F);
233 if (Changed) {
234 LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: "
235 << NumCallsReplaced << "\n");
236
237 PreservedAnalyses PA;
238 PA.preserveSet<CFGAnalyses>();
239 PA.preserve<TargetLibraryAnalysis>();
240 PA.preserve<ScalarEvolutionAnalysis>();
241 PA.preserve<LoopAccessAnalysis>();
242 PA.preserve<DemandedBitsAnalysis>();
243 PA.preserve<OptimizationRemarkEmitterAnalysis>();
244 return PA;
245 }
246
247 // The pass did not replace any calls, hence it preserves all analyses.
248 return PreservedAnalyses::all();
249 }
250
251 ////////////////////////////////////////////////////////////////////////////////
252 // Legacy PM Implementation.
253 ////////////////////////////////////////////////////////////////////////////////
runOnFunction(Function & F)254 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
255 const TargetLibraryInfo &TLI =
256 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
257 return runImpl(TLI, F);
258 }
259
getAnalysisUsage(AnalysisUsage & AU) const260 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
261 AU.setPreservesCFG();
262 AU.addRequired<TargetLibraryInfoWrapperPass>();
263 AU.addPreserved<TargetLibraryInfoWrapperPass>();
264 AU.addPreserved<ScalarEvolutionWrapperPass>();
265 AU.addPreserved<AAResultsWrapperPass>();
266 AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
267 AU.addPreserved<GlobalsAAWrapperPass>();
268 }
269
270 ////////////////////////////////////////////////////////////////////////////////
271 // Legacy Pass manager initialization
272 ////////////////////////////////////////////////////////////////////////////////
273 char ReplaceWithVeclibLegacy::ID = 0;
274
275 INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
276 "Replace intrinsics with calls to vector library", false,
277 false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)278 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
279 INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
280 "Replace intrinsics with calls to vector library", false,
281 false)
282
283 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
284 return new ReplaceWithVeclibLegacy();
285 }
286