//===----- SVEIntrinsicOpts - SVE ACLE Intrinsics Opts --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Performs general IR-level optimizations on SVE intrinsics.
//
// This pass performs the following optimizations:
//
// - removes unnecessary ptrue intrinsics (llvm.aarch64.sve.ptrue), e.g.:
//     %1 = @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
//     %2 = @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
//     ; (%1 can be replaced with a reinterpret of %2)
//
// - replaces loads and stores of fixed-width predicate-sized vectors with
//   loads and stores of scalable predicates, when the exact vector length is
//   known from the function's vscale_range attribute.
//
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64-sve-intrinsic-opts"

namespace {
struct SVEIntrinsicOpts : public ModulePass {
  static char ID; // Pass identification, replacement for typeid
  SVEIntrinsicOpts() : ModulePass(ID) {
    initializeSVEIntrinsicOptsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;
  void getAnalysisUsage(AnalysisUsage &AU) const override;

private:
  bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                   SmallSetVector<IntrinsicInst *, 4> &PTrues);
  bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
  bool optimizePredicateStore(Instruction *I);
  bool optimizePredicateLoad(Instruction *I);

  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

  /// Operates at function scope, i.e. optimizations are applied locally
  /// within each function.
  bool optimizeFunctions(SmallSetVector<Function *, 4> &Functions);
};
} // end anonymous namespace

void SVEIntrinsicOpts::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.setPreservesCFG();
}

char SVEIntrinsicOpts::ID = 0;
static const char *name = "SVE intrinsics optimizations";
INITIALIZE_PASS_BEGIN(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass);
INITIALIZE_PASS_END(SVEIntrinsicOpts, DEBUG_TYPE, name, false, false)

ModulePass *llvm::createSVEIntrinsicOptsPass() {
  return new SVEIntrinsicOpts();
}

/// Checks if a ptrue intrinsic call is promoted. The act of promoting a ptrue
/// will introduce zeroing. For example:
///
///   %1 = <vscale x 4 x i1> call @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
///   %2 = <vscale x 16 x i1> call @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
///   %3 = <vscale x 8 x i1> call @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
///
/// %1 is promoted, because it is converted:
///
///   <vscale x 4 x i1> => <vscale x 16 x i1> => <vscale x 8 x i1>
///
/// via a sequence of the SVE reinterpret intrinsics convert.{to,from}.svbool.
static bool isPTruePromoted(IntrinsicInst *PTrue) {
  // Find all users of this intrinsic that are calls to convert-to-svbool
  // reinterpret intrinsics.
  SmallVector<IntrinsicInst *, 4> ConvertToUses;
  for (User *User : PTrue->users()) {
    if (match(User, m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>())) {
      ConvertToUses.push_back(cast<IntrinsicInst>(User));
    }
  }

  // If no such calls were found, this ptrue is not promoted.
  if (ConvertToUses.empty())
    return false;

  // Otherwise, try to find users of the convert-to-svbool intrinsics that are
  // calls to the convert-from-svbool intrinsic, and would result in some lanes
  // being zeroed.
  const auto *PTrueVTy = cast<ScalableVectorType>(PTrue->getType());
  for (IntrinsicInst *ConvertToUse : ConvertToUses) {
    for (User *User : ConvertToUse->users()) {
      auto *IntrUser = dyn_cast<IntrinsicInst>(User);
      if (IntrUser && IntrUser->getIntrinsicID() ==
                          Intrinsic::aarch64_sve_convert_from_svbool) {
        const auto *IntrUserVTy = cast<ScalableVectorType>(IntrUser->getType());

        // Would some lanes become zeroed by the conversion?
        if (IntrUserVTy->getElementCount().getKnownMinValue() >
            PTrueVTy->getElementCount().getKnownMinValue())
          // This is a promoted ptrue.
          return true;
      }
    }
  }

  // If no matching calls were found, this is not a promoted ptrue.
  return false;
}

/// Attempts to coalesce ptrues in a basic block.
bool SVEIntrinsicOpts::coalescePTrueIntrinsicCalls(
    BasicBlock &BB, SmallSetVector<IntrinsicInst *, 4> &PTrues) {
  if (PTrues.size() <= 1)
    return false;

  // Find the ptrue with the most lanes.
  auto *MostEncompassingPTrue =
      *llvm::max_element(PTrues, [](auto *PTrue1, auto *PTrue2) {
        auto *PTrue1VTy = cast<ScalableVectorType>(PTrue1->getType());
        auto *PTrue2VTy = cast<ScalableVectorType>(PTrue2->getType());
        return PTrue1VTy->getElementCount().getKnownMinValue() <
               PTrue2VTy->getElementCount().getKnownMinValue();
      });

  // Remove the most encompassing ptrue, as well as any promoted ptrues,
  // leaving behind only the ptrues to be coalesced.
  PTrues.remove(MostEncompassingPTrue);
  PTrues.remove_if(isPTruePromoted);

  // Hoist MostEncompassingPTrue to the start of the basic block. It is always
  // safe to do this, since the only operand of a ptrue intrinsic call is an
  // immediate pattern, so it does not depend on any earlier instruction.
  MostEncompassingPTrue->moveBefore(BB, BB.getFirstInsertionPt());

  LLVMContext &Ctx = BB.getContext();
  IRBuilder<> Builder(Ctx);
  Builder.SetInsertPoint(&BB, ++MostEncompassingPTrue->getIterator());

  auto *MostEncompassingPTrueVTy =
      cast<VectorType>(MostEncompassingPTrue->getType());
  auto *ConvertToSVBool = Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {MostEncompassingPTrueVTy},
      {MostEncompassingPTrue});

  bool ConvertFromCreated = false;
  for (auto *PTrue : PTrues) {
    auto *PTrueVTy = cast<VectorType>(PTrue->getType());

    // Only create the converts if the types are not already the same;
    // otherwise just use the most encompassing ptrue.
    if (MostEncompassingPTrueVTy != PTrueVTy) {
      ConvertFromCreated = true;

      Builder.SetInsertPoint(&BB, ++ConvertToSVBool->getIterator());
      auto *ConvertFromSVBool =
          Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                  {PTrueVTy}, {ConvertToSVBool});
      PTrue->replaceAllUsesWith(ConvertFromSVBool);
    } else
      PTrue->replaceAllUsesWith(MostEncompassingPTrue);

    PTrue->eraseFromParent();
  }

  // We never used ConvertToSVBool, so remove it.
  if (!ConvertFromCreated)
    ConvertToSVBool->eraseFromParent();

  return true;
}

/// The goal of this function is to remove redundant calls to the SVE ptrue
/// intrinsic in each basic block within the given functions.
///
/// SVE ptrues have two representations in LLVM IR:
/// - a logical representation -- an arbitrary-width scalable vector of i1s,
///   i.e. <vscale x N x i1>.
/// - a physical representation (svbool, <vscale x 16 x i1>) -- a 16-element
///   scalable vector of i1s, i.e. <vscale x 16 x i1>.
///
/// The SVE ptrue intrinsic is used to create a logical representation of an
/// SVE predicate. Suppose that we have two SVE ptrue intrinsic calls: P1 and
/// P2. If P1 creates a logical SVE predicate that is at least as wide as the
/// logical SVE predicate created by P2, then all of the bits that are true in
/// the physical representation of P2 are necessarily also true in the
/// physical representation of P1. P1 'encompasses' P2; therefore, the
/// intrinsic call to P2 is redundant and can be replaced by an SVE reinterpret
/// of P1 via convert.{to,from}.svbool.
///
/// Currently, this pass only coalesces calls to SVE ptrue intrinsics
/// if they match the following conditions:
///
/// - the call to the intrinsic uses either the SV_ALL or SV_POW2 patterns.
///   SV_ALL indicates that all bits of the predicate vector are to be set to
///   true. SV_POW2 indicates that all bits of the predicate vector up to the
///   largest power-of-two are to be set to true.
/// - the result of the call to the intrinsic is not promoted to a wider
///   predicate. In this case, keeping the extra ptrue leads to better codegen
///   -- coalescing here would create an irreducible chain of SVE reinterprets
///   via convert.{to,from}.svbool.
///
/// EXAMPLE:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1, 1, 1, 1, 1>
///     ; Physical: <1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0>
///     ...
///
///     %2 = <vscale x 4 x i1> ptrue(i32 SV_ALL)
///     ; Logical:  <1, 1, 1, 1>
///     ; Physical: <1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0>
///     ...
///
/// Here, %2 can be replaced by an SVE reinterpret of %1, giving, for instance:
///
///     %1 = <vscale x 8 x i1> ptrue(i32 SV_ALL)
///     %2 = <vscale x 16 x i1> convert.to.svbool(<vscale x 8 x i1> %1)
///     %3 = <vscale x 4 x i1> convert.from.svbool(<vscale x 16 x i1> %2)
///
bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    for (auto &BB : *F) {
      SmallSetVector<IntrinsicInst *, 4> SVAllPTrues;
      SmallSetVector<IntrinsicInst *, 4> SVPow2PTrues;

      // For each basic block, collect the used ptrues and try to coalesce
      // them.
      for (Instruction &I : BB) {
        if (I.use_empty())
          continue;

        auto *IntrI = dyn_cast<IntrinsicInst>(&I);
        if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
          continue;

        const auto PTruePattern =
            cast<ConstantInt>(IntrI->getOperand(0))->getZExtValue();

        if (PTruePattern == AArch64SVEPredPattern::all)
          SVAllPTrues.insert(IntrI);
        if (PTruePattern == AArch64SVEPredPattern::pow2)
          SVPow2PTrues.insert(IntrI);
      }

      Changed |= coalescePTrueIntrinsicCalls(BB, SVAllPTrues);
      Changed |= coalescePTrueIntrinsicCalls(BB, SVPow2PTrues);
    }
  }

  return Changed;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable stores as late as possible.
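//
// Illustrative sketch (not taken from a specific test; %pred, %cast, %fixed
// and %addr are placeholder names): assuming the function carries
// vscale_range(2,2), a predicate's worth of bits is a fixed <4 x i8>, and the
// pattern matched below looks roughly like
//
//   %cast  = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
//   %fixed = call <4 x i8> @llvm.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %cast, i64 0)
//   store <4 x i8> %fixed, ptr %addr
//
// which is rewritten into a single scalable store of the predicate itself:
//
//   store <vscale x 16 x i1> %pred, ptr %addr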
bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a store..
  auto *Store = dyn_cast<StoreInst>(I);
  if (!Store || !Store->isSimple())
    return false;

  // ..that is storing a predicate vector's worth of bits..
  if (Store->getOperand(0)->getType() != FixedPredType)
    return false;

  // ..where the value stored comes from a vector extract..
  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_extract)
    return false;

  // ..that is extracting from index 0..
  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
    return false;

  // ..where the value being extracted from comes from a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
  if (!BitCast)
    return false;

  // ..and the bitcast is casting from predicate type.
  if (BitCast->getOperand(0)->getType() != PredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);

  Builder.CreateStore(BitCast->getOperand(0), Store->getPointerOperand());

  Store->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (BitCast->getNumUses() == 0)
    BitCast->eraseFromParent();

  return true;
}

// This is done in SVEIntrinsicOpts rather than InstCombine so that we
// introduce scalable loads as late as possible.
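//
// Illustrative sketch, under the same vscale_range(2,2) assumption and with
// placeholder value names: the pattern matched below looks roughly like
//
//   %fixed  = load <4 x i8>, ptr %addr
//   %insert = call <vscale x 2 x i8> @llvm.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %fixed, i64 0)
//   %pred   = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
//
// which is rewritten into a single scalable load of the predicate:
//
//   %pred = load <vscale x 16 x i1>, ptr %addr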
bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
  auto *F = I->getFunction();
  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
  if (!Attr.isValid())
    return false;

  unsigned MinVScale = Attr.getVScaleRangeMin();
  std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax();
  // The transform needs to know the exact runtime length of scalable vectors.
  if (!MaxVScale || MinVScale != MaxVScale)
    return false;

  auto *PredType =
      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
  auto *FixedPredType =
      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);

  // If we have a bitcast..
  auto *BitCast = dyn_cast<BitCastInst>(I);
  if (!BitCast || BitCast->getType() != PredType)
    return false;

  // ..whose operand is a vector_insert..
  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
  if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::vector_insert)
    return false;

  // ..that is inserting into index zero of an undef vector..
  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
    return false;

  // ..where the value inserted comes from a load..
  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
  if (!Load || !Load->isSimple())
    return false;

  // ..that is loading a predicate vector's worth of bits.
  if (Load->getType() != FixedPredType)
    return false;

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(Load);

  auto *LoadPred = Builder.CreateLoad(PredType, Load->getPointerOperand());

  BitCast->replaceAllUsesWith(LoadPred);
  BitCast->eraseFromParent();
  if (IntrI->getNumUses() == 0)
    IntrI->eraseFromParent();
  if (Load->getNumUses() == 0)
    Load->eraseFromParent();

  return true;
}

bool SVEIntrinsicOpts::optimizeInstructions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  for (auto *F : Functions) {
    DominatorTree *DT =
        &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();

    // Traverse the DT with an RPO walk so we see defs before uses, allowing
    // simplification to be done incrementally.
    BasicBlock *Root = DT->getRoot();
    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
    for (auto *BB : RPOT) {
      for (Instruction &I : make_early_inc_range(*BB)) {
        switch (I.getOpcode()) {
        case Instruction::Store:
          Changed |= optimizePredicateStore(&I);
          break;
        case Instruction::BitCast:
          Changed |= optimizePredicateLoad(&I);
          break;
        }
      }
    }
  }

  return Changed;
}

bool SVEIntrinsicOpts::optimizeFunctions(
    SmallSetVector<Function *, 4> &Functions) {
  bool Changed = false;

  Changed |= optimizePTrueIntrinsicCalls(Functions);
  Changed |= optimizeInstructions(Functions);

  return Changed;
}

bool SVEIntrinsicOpts::runOnModule(Module &M) {
  bool Changed = false;
  SmallSetVector<Function *, 4> Functions;

  // Check for SVE intrinsic declarations first so that we only iterate over
  // relevant functions. Where an appropriate declaration is found, record the
  // function(s) in which it is used so we can target those only.
  for (auto &F : M.getFunctionList()) {
    if (!F.isDeclaration())
      continue;

    switch (F.getIntrinsicID()) {
    case Intrinsic::vector_extract:
    case Intrinsic::vector_insert:
    case Intrinsic::aarch64_sve_ptrue:
      for (User *U : F.users())
        Functions.insert(cast<Instruction>(U)->getFunction());
      break;
    default:
      break;
    }
  }

  if (!Functions.empty())
    Changed |= optimizeFunctions(Functions);

  return Changed;
}