xref: /freebsd/contrib/llvm-project/llvm/lib/CodeGen/ReplaceWithVeclib.cpp (revision a2fda816eb054d5873be223ef2461741dfcc253c)
1  //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2  //
3  // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4  // See https://llvm.org/LICENSE.txt for license information.
5  // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6  //
7  //===----------------------------------------------------------------------===//
8  //
9  // Replaces LLVM IR instructions with vector operands (i.e., the frem
10  // instruction or calls to LLVM intrinsics) with matching calls to functions
11  // from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
12  //
13  //===----------------------------------------------------------------------===//
14  
15  #include "llvm/CodeGen/ReplaceWithVeclib.h"
16  #include "llvm/ADT/STLExtras.h"
17  #include "llvm/ADT/Statistic.h"
18  #include "llvm/ADT/StringRef.h"
19  #include "llvm/Analysis/DemandedBits.h"
20  #include "llvm/Analysis/GlobalsModRef.h"
21  #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22  #include "llvm/Analysis/TargetLibraryInfo.h"
23  #include "llvm/Analysis/VectorUtils.h"
24  #include "llvm/CodeGen/Passes.h"
25  #include "llvm/IR/DerivedTypes.h"
26  #include "llvm/IR/IRBuilder.h"
27  #include "llvm/IR/InstIterator.h"
28  #include "llvm/IR/VFABIDemangler.h"
29  #include "llvm/Support/TypeSize.h"
30  #include "llvm/Transforms/Utils/ModuleUtils.h"
31  
32  using namespace llvm;
33  
34  #define DEBUG_TYPE "replace-with-veclib"
35  
36  STATISTIC(NumCallsReplaced,
37            "Number of calls to intrinsics that have been replaced.");
38  
39  STATISTIC(NumTLIFuncDeclAdded,
40            "Number of vector library function declarations added.");
41  
42  STATISTIC(NumFuncUsedAdded,
43            "Number of functions added to `llvm.compiler.used`");
44  
45  /// Returns a vector Function that it adds to the Module \p M. When an \p
46  /// ScalarFunc is not null, it copies its attributes to the newly created
47  /// Function.
48  Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49                           const StringRef TLIName,
50                           Function *ScalarFunc = nullptr) {
51    Function *TLIFunc = M->getFunction(TLIName);
52    if (!TLIFunc) {
53      TLIFunc =
54          Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
55      if (ScalarFunc)
56        TLIFunc->copyAttributesFrom(ScalarFunc);
57  
58      LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59                        << TLIName << "` of type `" << *(TLIFunc->getType())
60                        << "` to module.\n");
61  
62      ++NumTLIFuncDeclAdded;
63      // Add the freshly created function to llvm.compiler.used, similar to as it
64      // is done in InjectTLIMappings.
65      appendToCompilerUsed(*M, {TLIFunc});
66      LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67                        << "` to `@llvm.compiler.used`.\n");
68      ++NumFuncUsedAdded;
69    }
70    return TLIFunc;
71  }
72  
73  /// Replace the instruction \p I with a call to the corresponding function from
74  /// the vector library (\p TLIVecFunc).
75  static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
76                                     Function *TLIVecFunc) {
77    IRBuilder<> IRBuilder(&I);
78    auto *CI = dyn_cast<CallInst>(&I);
79    SmallVector<Value *> Args(CI ? CI->args() : I.operands());
80    if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
81      auto *MaskTy =
82          VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
83      Args.insert(Args.begin() + OptMaskpos.value(),
84                  Constant::getAllOnesValue(MaskTy));
85    }
86  
87    // If it is a call instruction, preserve the operand bundles.
88    SmallVector<OperandBundleDef, 1> OpBundles;
89    if (CI)
90      CI->getOperandBundlesAsDefs(OpBundles);
91  
92    auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
93    I.replaceAllUsesWith(Replacement);
94    // Preserve fast math flags for FP math.
95    if (isa<FPMathOperator>(Replacement))
96      Replacement->copyFastMathFlags(&I);
97  }
98  
99  /// Returns true when successfully replaced \p I with a suitable function taking
100  /// vector arguments, based on available mappings in the \p TLI. Currently only
101  /// works when \p I is a call to vectorized intrinsic or the frem instruction.
102  static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
103                                      Instruction &I) {
104    // At the moment VFABI assumes the return type is always widened unless it is
105    // a void type.
106    auto *VTy = dyn_cast<VectorType>(I.getType());
107    ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
108  
109    // Compute the argument types of the corresponding scalar call and the scalar
110    // function name. For calls, it additionally finds the function to replace
111    // and checks that all vector operands match the previously found EC.
112    SmallVector<Type *, 8> ScalarArgTypes;
113    std::string ScalarName;
114    Function *FuncToReplace = nullptr;
115    auto *CI = dyn_cast<CallInst>(&I);
116    if (CI) {
117      FuncToReplace = CI->getCalledFunction();
118      Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
119      assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
120      for (auto Arg : enumerate(CI->args())) {
121        auto *ArgTy = Arg.value()->getType();
122        if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123          ScalarArgTypes.push_back(ArgTy);
124        } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
125          ScalarArgTypes.push_back(VectorArgTy->getElementType());
126          // When return type is void, set EC to the first vector argument, and
127          // disallow vector arguments with different ECs.
128          if (EC.isZero())
129            EC = VectorArgTy->getElementCount();
130          else if (EC != VectorArgTy->getElementCount())
131            return false;
132        } else
133          // Exit when it is supposed to be a vector argument but it isn't.
134          return false;
135      }
136      // Try to reconstruct the name for the scalar version of the instruction,
137      // using scalar argument types.
138      ScalarName = Intrinsic::isOverloaded(IID)
139                       ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
140                       : Intrinsic::getName(IID).str();
141    } else {
142      assert(VTy && "Return type must be a vector");
143      auto *ScalarTy = VTy->getScalarType();
144      LibFunc Func;
145      if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
146        return false;
147      ScalarName = TLI.getName(Func);
148      ScalarArgTypes = {ScalarTy, ScalarTy};
149    }
150  
151    // Try to find the mapping for the scalar version of this intrinsic and the
152    // exact vector width of the call operands in the TargetLibraryInfo. First,
153    // check with a non-masked variant, and if that fails try with a masked one.
154    const VecDesc *VD =
155        TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
156    if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
157      return false;
158  
159    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
160                      << "` and vector width " << EC << " to: `"
161                      << VD->getVectorFnName() << "`.\n");
162  
163    // Replace the call to the intrinsic with a call to the vector library
164    // function.
165    Type *ScalarRetTy = I.getType()->getScalarType();
166    FunctionType *ScalarFTy =
167        FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
168    const std::string MangledName = VD->getVectorFunctionABIVariantString();
169    auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
170    if (!OptInfo)
171      return false;
172  
173    // There is no guarantee that the vectorized instructions followed the VFABI
174    // specification when being created, this is why we need to add extra check to
175    // make sure that the operands of the vector function obtained via VFABI match
176    // the operands of the original vector instruction.
177    if (CI) {
178      for (auto VFParam : OptInfo->Shape.Parameters) {
179        if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
180          continue;
181  
182        // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
183        // a bug in the VFABI parser.
184        assert(VFParam.ParamPos < CI->arg_size() &&
185               "ParamPos has invalid range.");
186        Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
187        if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
188          LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
189                            << ". Wrong type at index " << VFParam.ParamPos
190                            << ": " << *OrigTy << "\n");
191          return false;
192        }
193      }
194    }
195  
196    FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
197    if (!VectorFTy)
198      return false;
199  
200    Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
201                                       VD->getVectorFnName(), FuncToReplace);
202  
203    replaceWithTLIFunction(I, *OptInfo, TLIFunc);
204    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
205                      << "` with call to `" << TLIFunc->getName() << "`.\n");
206    ++NumCallsReplaced;
207    return true;
208  }
209  
210  /// Supported instruction \p I must be a vectorized frem or a call to an
211  /// intrinsic that returns either void or a vector.
212  static bool isSupportedInstruction(Instruction *I) {
213    Type *Ty = I->getType();
214    if (auto *CI = dyn_cast<CallInst>(I))
215      return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
216             CI->getCalledFunction()->getIntrinsicID() !=
217                 Intrinsic::not_intrinsic;
218    if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
219      return true;
220    return false;
221  }
222  
223  static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
224    bool Changed = false;
225    SmallVector<Instruction *> ReplacedCalls;
226    for (auto &I : instructions(F)) {
227      if (!isSupportedInstruction(&I))
228        continue;
229      if (replaceWithCallToVeclib(TLI, I)) {
230        ReplacedCalls.push_back(&I);
231        Changed = true;
232      }
233    }
234    // Erase the calls to the intrinsics that have been replaced
235    // with calls to the vector library.
236    for (auto *CI : ReplacedCalls)
237      CI->eraseFromParent();
238    return Changed;
239  }
240  
241  ////////////////////////////////////////////////////////////////////////////////
242  // New pass manager implementation.
243  ////////////////////////////////////////////////////////////////////////////////
244  PreservedAnalyses ReplaceWithVeclib::run(Function &F,
245                                           FunctionAnalysisManager &AM) {
246    const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
247    auto Changed = runImpl(TLI, F);
248    if (Changed) {
249      LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
250                        << NumCallsReplaced << "\n");
251  
252      PreservedAnalyses PA;
253      PA.preserveSet<CFGAnalyses>();
254      PA.preserve<TargetLibraryAnalysis>();
255      PA.preserve<ScalarEvolutionAnalysis>();
256      PA.preserve<LoopAccessAnalysis>();
257      PA.preserve<DemandedBitsAnalysis>();
258      PA.preserve<OptimizationRemarkEmitterAnalysis>();
259      return PA;
260    }
261  
262    // The pass did not replace any calls, hence it preserves all analyses.
263    return PreservedAnalyses::all();
264  }
265  
266  ////////////////////////////////////////////////////////////////////////////////
267  // Legacy PM Implementation.
268  ////////////////////////////////////////////////////////////////////////////////
269  bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
270    const TargetLibraryInfo &TLI =
271        getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
272    return runImpl(TLI, F);
273  }
274  
275  void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
276    AU.setPreservesCFG();
277    AU.addRequired<TargetLibraryInfoWrapperPass>();
278    AU.addPreserved<TargetLibraryInfoWrapperPass>();
279    AU.addPreserved<ScalarEvolutionWrapperPass>();
280    AU.addPreserved<AAResultsWrapperPass>();
281    AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
282    AU.addPreserved<GlobalsAAWrapperPass>();
283  }
284  
285  ////////////////////////////////////////////////////////////////////////////////
286  // Legacy Pass manager initialization
287  ////////////////////////////////////////////////////////////////////////////////
288  char ReplaceWithVeclibLegacy::ID = 0;
289  
290  INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
291                        "Replace intrinsics with calls to vector library", false,
292                        false)
293  INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
294  INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
295                      "Replace intrinsics with calls to vector library", false,
296                      false)
297  
298  FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
299    return new ReplaceWithVeclibLegacy();
300  }
301