//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements, each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //   VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //   A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //   T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVL may be folded into a VMULL.

  // But expensive extends/truncs are always good to remove. FPExts always
  // involve extra VCVTs so are always considered to be beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext); see if all of the extends feed a
  // vmull. This is a simple heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Ext's, terminating at Truncs.
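  // The worklist search below collects:
  //  - Truncs:     trunc/fptrunc instructions that consume the wide values,
  //  - Reducts:    vector_reduce_add calls, which can also act as sinks,
  //  - Exts:       sext/zext/fpext leaves feeding the blob,
  //  - OtherLeafs: other vector operands (constants, splats, arguments, ...),
  //  - Ops:        the lane-wise operations connecting the leaves to the sinks.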
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Reducts;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
        if (!Reducts.insert(I))
          continue;
        Visited.insert(I);
        break;
      }

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }
    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:\n";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:\n";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:\n";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "  Truncs:\n";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Reducts:\n";
    for (auto *I : Reducts)
      dbgs() << "  " << *I << "\n";
  });

  assert((!Truncs.empty() || !Reducts.empty()) &&
         "Expected some truncs or reductions");
  if (Truncs.empty() && Exts.empty())
    return false;

  auto *VT = !Truncs.empty()
                 ? cast<FixedVectorType>(Truncs[0]->getType())
                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");

  // Check types
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;
  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
        return I->getOpcode() == Instruction::Mul ||
               I->getOpcode() == Instruction::Select ||
               I->getOpcode() == Instruction::ICmp;
      }))) {
    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
    return false;
  }

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }
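
  // Note: within each BaseElts-wide block, TruncMask is the inverse
  // permutation of LeafMask, so a value that is de-interleaved at a leaf and
  // re-interleaved after a trunc ends up back in its original lane.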
  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                       : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                              : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

// Add reductions are fairly common and associative, meaning we can start the
// interleaving from them and don't need to emit a shuffle.
static bool isAddReduction(Instruction &I) {
  if (auto *II = dyn_cast<IntrinsicInst>(&I))
    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
  return false;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
  for (Instruction &I : reverse(instructions(F))) {
    if (((I.getType()->isVectorTy() &&
          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
         isAddReduction(I)) &&
        !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}