//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - optimizes ptest intrinsics where the operands are being needlessly
//   converted to and from svbool_t.
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);

  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at function scope, i.e. optimizations are applied locally to
  /// each of the functions themselves.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a
/// ptrue will introduce zeroing.
/// For example:
///
///   %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///   %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///   %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///   <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue =
      *llvm::max_element(PTrues, [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if(isPTruePromoted);
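
  // As an illustrative sketch (value names and element counts are invented
  // for the example), coalescing an SV_ALL nxv8i1/nxv4i1 ptrue pair keeps the
  // widest ptrue and rederives the narrower one from it:
  //
  //   %most   = call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
  //   %svbool = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %most)
  //   %narrow = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %svbool)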

  // Hoist MostEncompassingPTrue to the start of the basic block. It is always
  // safe to do this, since ptrue intrinsic calls are guaranteed to have no
  // predecessors.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same,
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used the ConvertTo, so remove it.
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an
/// SVE predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and
/// P2. If P1 creates a logical SVE predicate that is at least as wide as the
/// logical SVE predicate created by P2, then all of the bits that are true in
/// the physical representation of P2 are necessarily also true in the
/// physical representation of P1; P1 'encompasses' P2. Therefore, the
/// intrinsic call to P2 is redundant and can be replaced by an SVE
/// reinterpret of P1 via convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 31)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce
      // them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}
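
// An illustrative sketch of the store rewrite performed below, assuming a
// function annotated with vscale_range(1,1) so that a full predicate register
// is known to be 2 bytes (value names are invented for the example):
//
// Before:
//   %bc  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
//   %ext = call <2 x i8> @llvm.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %bc, i64 0)
//   store <2 x i8> %ext, ptr %addr
//
// After:
//   store <vscale x 16 x i1> %pred, ptr %addr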

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable stores as late as possible.
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a store..
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ..that is storing a predicate vector's worth of bits..
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ..where the value stored comes from a vector extract..
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ..that is extracting from index 0..
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ..where the value being extracted from comes from a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ..and the bitcast is casting from predicate type.
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  Builder.CreateStore(BitCast->getOperand(0), Store->getPointerOperand());

  Store->eraseFromParent();
  if (IntrI->use_empty())
    IntrI->eraseFromParent();
  if (BitCast->use_empty())
    BitCast->eraseFromParent();

  return true;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable loads as late as possible.
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ..whose operand is a vector_insert..
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ..that is inserting into index zero of an undef vector..
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ..where the value inserted comes from a load..
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ..that is loading a predicate vector's worth of bits..
  if (Load->getType() != FixedPredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *LoadPred = Builder.CreateLoad(PredType, Load->getPointerOperand());

  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->use_empty())
    IntrI->eraseFromParent();
  if (Load->use_empty())
    Load->eraseFromParent();

  return true;
}
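
// The load rewrite above mirrors the store case. Again an illustrative
// sketch, assuming vscale_range(1,1) (value names invented for the example):
//
// Before:
//   %load = load <2 x i8>, ptr %addr
//   %ins  = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %load, i64 0)
//   %bc   = bitcast <vscale x 2 x i8> %ins to <vscale x 16 x i1>
//
// After:
//   %bc = load <vscale x 16 x i1>, ptr %addr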

bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    DominatorTree *DT =
        &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an RPO walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT) {
      for (Instruction &I : make_early_inc_range(*BB)) {
        switch (I.getOpcode()) {
        case Instruction::Store:
          Changed |= optimizePredicateStore(&I);
          break;
        case Instruction::BitCast:
          Changed |= optimizePredicateLoad(&I);
          break;
        }
      }
    }
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeInstructions(Functions);

  return Changed;
}

bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, store the
  // function(s) where it is used so we can target these only.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::vector_extract:
    case Intrinsic::vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}
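
// Usage note (an assumption about the surrounding build, not part of the pass
// itself): as a legacy-pass-manager ModulePass registered under DEBUG_TYPE,
// this pass can be exercised in isolation on builds where opt still supports
// legacy passes, e.g.:
//
//   opt -S -aarch64-sve-intrinsic-opts -mtriple=aarch64-linux-gnu \
//       -mattr=+sve input.ll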