//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g., libmvec, SVML) using the TargetLibraryInfo
// interface.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/ReplaceWithVeclib.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/DemandedBits.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/VFABIDemangler.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"

using namespace llvm;

#define DEBUG_TYPE "replace-with-veclib"

// Counters incremented by this file: one per replaced instruction, one per
// vector library declaration created, and one per function appended to
// `llvm.compiler.used`.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");

/// Returns a vector Function that it adds to the Module \p M. When
/// \p ScalarFunc is not null, it copies its attributes to the newly created
/// Function.
48 Function *getTLIFunction(Module *M, FunctionType *VectorFTy, 49 const StringRef TLIName, 50 Function *ScalarFunc = nullptr) { 51 Function *TLIFunc = M->getFunction(TLIName); 52 if (!TLIFunc) { 53 TLIFunc = 54 Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M); 55 if (ScalarFunc) 56 TLIFunc->copyAttributesFrom(ScalarFunc); 57 58 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `" 59 << TLIName << "` of type `" << *(TLIFunc->getType()) 60 << "` to module.\n"); 61 62 ++NumTLIFuncDeclAdded; 63 // Add the freshly created function to llvm.compiler.used, similar to as it 64 // is done in InjectTLIMappings. 65 appendToCompilerUsed(*M, {TLIFunc}); 66 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName 67 << "` to `@llvm.compiler.used`.\n"); 68 ++NumFuncUsedAdded; 69 } 70 return TLIFunc; 71 } 72 73 /// Replace the instruction \p I with a call to the corresponding function from 74 /// the vector library (\p TLIVecFunc). 75 static void replaceWithTLIFunction(Instruction &I, VFInfo &Info, 76 Function *TLIVecFunc) { 77 IRBuilder<> IRBuilder(&I); 78 auto *CI = dyn_cast<CallInst>(&I); 79 SmallVector<Value *> Args(CI ? CI->args() : I.operands()); 80 if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) { 81 auto *MaskTy = 82 VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF); 83 Args.insert(Args.begin() + OptMaskpos.value(), 84 Constant::getAllOnesValue(MaskTy)); 85 } 86 87 // If it is a call instruction, preserve the operand bundles. 88 SmallVector<OperandBundleDef, 1> OpBundles; 89 if (CI) 90 CI->getOperandBundlesAsDefs(OpBundles); 91 92 auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles); 93 I.replaceAllUsesWith(Replacement); 94 // Preserve fast math flags for FP math. 
95 if (isa<FPMathOperator>(Replacement)) 96 Replacement->copyFastMathFlags(&I); 97 } 98 99 /// Returns true when successfully replaced \p I with a suitable function taking 100 /// vector arguments, based on available mappings in the \p TLI. Currently only 101 /// works when \p I is a call to vectorized intrinsic or the frem instruction. 102 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI, 103 Instruction &I) { 104 // At the moment VFABI assumes the return type is always widened unless it is 105 // a void type. 106 auto *VTy = dyn_cast<VectorType>(I.getType()); 107 ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0)); 108 109 // Compute the argument types of the corresponding scalar call and the scalar 110 // function name. For calls, it additionally finds the function to replace 111 // and checks that all vector operands match the previously found EC. 112 SmallVector<Type *, 8> ScalarArgTypes; 113 std::string ScalarName; 114 Function *FuncToReplace = nullptr; 115 auto *CI = dyn_cast<CallInst>(&I); 116 if (CI) { 117 FuncToReplace = CI->getCalledFunction(); 118 Intrinsic::ID IID = FuncToReplace->getIntrinsicID(); 119 assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic"); 120 for (auto Arg : enumerate(CI->args())) { 121 auto *ArgTy = Arg.value()->getType(); 122 if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) { 123 ScalarArgTypes.push_back(ArgTy); 124 } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) { 125 ScalarArgTypes.push_back(VectorArgTy->getElementType()); 126 // When return type is void, set EC to the first vector argument, and 127 // disallow vector arguments with different ECs. 128 if (EC.isZero()) 129 EC = VectorArgTy->getElementCount(); 130 else if (EC != VectorArgTy->getElementCount()) 131 return false; 132 } else 133 // Exit when it is supposed to be a vector argument but it isn't. 
134 return false; 135 } 136 // Try to reconstruct the name for the scalar version of the instruction, 137 // using scalar argument types. 138 ScalarName = Intrinsic::isOverloaded(IID) 139 ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule()) 140 : Intrinsic::getName(IID).str(); 141 } else { 142 assert(VTy && "Return type must be a vector"); 143 auto *ScalarTy = VTy->getScalarType(); 144 LibFunc Func; 145 if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func)) 146 return false; 147 ScalarName = TLI.getName(Func); 148 ScalarArgTypes = {ScalarTy, ScalarTy}; 149 } 150 151 // Try to find the mapping for the scalar version of this intrinsic and the 152 // exact vector width of the call operands in the TargetLibraryInfo. First, 153 // check with a non-masked variant, and if that fails try with a masked one. 154 const VecDesc *VD = 155 TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false); 156 if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true))) 157 return false; 158 159 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName 160 << "` and vector width " << EC << " to: `" 161 << VD->getVectorFnName() << "`.\n"); 162 163 // Replace the call to the intrinsic with a call to the vector library 164 // function. 165 Type *ScalarRetTy = I.getType()->getScalarType(); 166 FunctionType *ScalarFTy = 167 FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false); 168 const std::string MangledName = VD->getVectorFunctionABIVariantString(); 169 auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy); 170 if (!OptInfo) 171 return false; 172 173 // There is no guarantee that the vectorized instructions followed the VFABI 174 // specification when being created, this is why we need to add extra check to 175 // make sure that the operands of the vector function obtained via VFABI match 176 // the operands of the original vector instruction. 
177 if (CI) { 178 for (auto VFParam : OptInfo->Shape.Parameters) { 179 if (VFParam.ParamKind == VFParamKind::GlobalPredicate) 180 continue; 181 182 // tryDemangleForVFABI must return valid ParamPos, otherwise it could be 183 // a bug in the VFABI parser. 184 assert(VFParam.ParamPos < CI->arg_size() && 185 "ParamPos has invalid range."); 186 Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType(); 187 if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) { 188 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName 189 << ". Wrong type at index " << VFParam.ParamPos 190 << ": " << *OrigTy << "\n"); 191 return false; 192 } 193 } 194 } 195 196 FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy); 197 if (!VectorFTy) 198 return false; 199 200 Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy, 201 VD->getVectorFnName(), FuncToReplace); 202 203 replaceWithTLIFunction(I, *OptInfo, TLIFunc); 204 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName 205 << "` with call to `" << TLIFunc->getName() << "`.\n"); 206 ++NumCallsReplaced; 207 return true; 208 } 209 210 /// Supported instruction \p I must be a vectorized frem or a call to an 211 /// intrinsic that returns either void or a vector. 
212 static bool isSupportedInstruction(Instruction *I) { 213 Type *Ty = I->getType(); 214 if (auto *CI = dyn_cast<CallInst>(I)) 215 return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() && 216 CI->getCalledFunction()->getIntrinsicID() != 217 Intrinsic::not_intrinsic; 218 if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy()) 219 return true; 220 return false; 221 } 222 223 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) { 224 bool Changed = false; 225 SmallVector<Instruction *> ReplacedCalls; 226 for (auto &I : instructions(F)) { 227 if (!isSupportedInstruction(&I)) 228 continue; 229 if (replaceWithCallToVeclib(TLI, I)) { 230 ReplacedCalls.push_back(&I); 231 Changed = true; 232 } 233 } 234 // Erase the calls to the intrinsics that have been replaced 235 // with calls to the vector library. 236 for (auto *CI : ReplacedCalls) 237 CI->eraseFromParent(); 238 return Changed; 239 } 240 241 //////////////////////////////////////////////////////////////////////////////// 242 // New pass manager implementation. 243 //////////////////////////////////////////////////////////////////////////////// 244 PreservedAnalyses ReplaceWithVeclib::run(Function &F, 245 FunctionAnalysisManager &AM) { 246 const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F); 247 auto Changed = runImpl(TLI, F); 248 if (Changed) { 249 LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: " 250 << NumCallsReplaced << "\n"); 251 252 PreservedAnalyses PA; 253 PA.preserveSet<CFGAnalyses>(); 254 PA.preserve<TargetLibraryAnalysis>(); 255 PA.preserve<ScalarEvolutionAnalysis>(); 256 PA.preserve<LoopAccessAnalysis>(); 257 PA.preserve<DemandedBitsAnalysis>(); 258 PA.preserve<OptimizationRemarkEmitterAnalysis>(); 259 return PA; 260 } 261 262 // The pass did not replace any calls, hence it preserves all analyses. 
263 return PreservedAnalyses::all(); 264 } 265 266 //////////////////////////////////////////////////////////////////////////////// 267 // Legacy PM Implementation. 268 //////////////////////////////////////////////////////////////////////////////// 269 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) { 270 const TargetLibraryInfo &TLI = 271 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 272 return runImpl(TLI, F); 273 } 274 275 void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const { 276 AU.setPreservesCFG(); 277 AU.addRequired<TargetLibraryInfoWrapperPass>(); 278 AU.addPreserved<TargetLibraryInfoWrapperPass>(); 279 AU.addPreserved<ScalarEvolutionWrapperPass>(); 280 AU.addPreserved<AAResultsWrapperPass>(); 281 AU.addPreserved<OptimizationRemarkEmitterWrapperPass>(); 282 AU.addPreserved<GlobalsAAWrapperPass>(); 283 } 284 285 //////////////////////////////////////////////////////////////////////////////// 286 // Legacy Pass manager initialization 287 //////////////////////////////////////////////////////////////////////////////// 288 char ReplaceWithVeclibLegacy::ID = 0; 289 290 INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE, 291 "Replace intrinsics with calls to vector library", false, 292 false) 293 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 294 INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE, 295 "Replace intrinsics with calls to vector library", false, 296 false) 297 298 FunctionPass *llvm::createReplaceWithVeclibLegacyPass() { 299 return new ReplaceWithVeclibLegacy(); 300 } 301