//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
// // This pass takes vector code that starts at truncs, looks for interconnected // blobs of operations that end with sext/zext (or constants/splats) of the // form: // %sa = sext v8i16 %a to v8i32 // %sb = sext v8i16 %b to v8i32 // %add = add v8i32 %sa, %sb // %r = trunc %add to v8i16 // And adds shuffles to allow the use of VMOVL/VMOVN instrctions: // %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7> // %sa = sext v8i16 %sha to v8i32 // %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7> // %sb = sext v8i16 %shb to v8i32 // %add = add v8i32 %sa, %sb // %r = trunc %add to v8i16 // %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7> // Which can then be split and lowered to MVE instructions efficiently: // %sa_b = VMOVLB.s16 %a // %sa_t = VMOVLT.s16 %a // %sb_b = VMOVLB.s16 %b // %sb_t = VMOVLT.s16 %b // %add_b = VADD.i32 %sa_b, %sb_b // %add_t = VADD.i32 %sa_t, %sb_t // %r = VMOVNT.i16 %add_b, %add_t // //===----------------------------------------------------------------------===// #include "ARM.h" #include "ARMBaseInstrInfo.h" #include "ARMSubtarget.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/TargetLowering.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/TargetSubtargetInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include #include using namespace llvm; #define DEBUG_TYPE "mve-laneinterleave" cl::opt 
EnableInterleave( "enable-mve-interleave", cl::Hidden, cl::init(true), cl::desc("Enable interleave MVE vector operation lowering")); namespace { class MVELaneInterleaving : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid explicit MVELaneInterleaving() : FunctionPass(ID) { initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override; StringRef getPassName() const override { return "MVE lane interleaving"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); FunctionPass::getAnalysisUsage(AU); } }; } // end anonymous namespace char MVELaneInterleaving::ID = 0; INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false, false) Pass *llvm::createMVELaneInterleavingPass() { return new MVELaneInterleaving(); } static bool isProfitableToInterleave(SmallSetVector &Exts, SmallSetVector &Truncs) { // This is not always beneficial to transform. Exts can be incorporated into // loads, Truncs can be folded into stores. // Truncs are usually the same number of instructions, // VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving // Exts are unfortunately more instructions in the general case: // A=VLDRH.32; B=VLDRH.32; // vs with interleaving: // T=VLDRH.16; A=VMOVNB T; B=VMOVNT T // But those VMOVL may be folded into a VMULL. // But expensive extends/truncs are always good to remove. FPExts always // involve extra VCVT's so are always considered to be beneficial to convert. for (auto *E : Exts) { if (isa(E) || !isa(E->getOperand(0))) { LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n"); return true; } } for (auto *T : Truncs) { if (T->hasOneUse() && !isa(*T->user_begin())) { LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n"); return true; } } // Otherwise, we know we have a load(ext), see if any of the Extends are a // vmull. This is a simple heuristic and certainly not perfect. 
for (auto *E : Exts) { if (!E->hasOneUse() || cast(*E->user_begin())->getOpcode() != Instruction::Mul) { LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n"); return false; } } return true; } static bool tryInterleave(Instruction *Start, SmallPtrSetImpl &Visited) { LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n"); auto *VT = cast(Start->getType()); if (!isa(Start->getOperand(0))) return false; // Look for connected operations starting from Ext's, terminating at Truncs. std::vector Worklist; Worklist.push_back(Start); Worklist.push_back(cast(Start->getOperand(0))); SmallSetVector Truncs; SmallSetVector Exts; SmallSetVector OtherLeafs; SmallSetVector Ops; while (!Worklist.empty()) { Instruction *I = Worklist.back(); Worklist.pop_back(); switch (I->getOpcode()) { // Truncs case Instruction::Trunc: case Instruction::FPTrunc: if (!Truncs.insert(I)) continue; Visited.insert(I); break; // Extend leafs case Instruction::SExt: case Instruction::ZExt: case Instruction::FPExt: if (Exts.count(I)) continue; for (auto *Use : I->users()) Worklist.push_back(cast(Use)); Exts.insert(I); break; case Instruction::Call: { IntrinsicInst *II = dyn_cast(I); if (!II) return false; switch (II->getIntrinsicID()) { case Intrinsic::abs: case Intrinsic::smin: case Intrinsic::smax: case Intrinsic::umin: case Intrinsic::umax: case Intrinsic::sadd_sat: case Intrinsic::ssub_sat: case Intrinsic::uadd_sat: case Intrinsic::usub_sat: case Intrinsic::minnum: case Intrinsic::maxnum: case Intrinsic::fabs: case Intrinsic::fma: case Intrinsic::ceil: case Intrinsic::floor: case Intrinsic::rint: case Intrinsic::round: case Intrinsic::trunc: break; default: return false; } LLVM_FALLTHROUGH; // Fall through to treating these like an operator below. 
} // Binary/tertiary ops case Instruction::Add: case Instruction::Sub: case Instruction::Mul: case Instruction::AShr: case Instruction::LShr: case Instruction::Shl: case Instruction::ICmp: case Instruction::FCmp: case Instruction::FAdd: case Instruction::FMul: case Instruction::Select: if (!Ops.insert(I)) continue; for (Use &Op : I->operands()) { if (!isa(Op->getType())) continue; if (isa(Op)) Worklist.push_back(cast(&Op)); else OtherLeafs.insert(&Op); } for (auto *Use : I->users()) Worklist.push_back(cast(Use)); break; case Instruction::ShuffleVector: // A shuffle of a splat is a splat. if (cast(I)->isZeroEltSplat()) continue; LLVM_FALLTHROUGH; default: LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n"); return false; } } if (Exts.empty() && OtherLeafs.empty()) return false; LLVM_DEBUG({ dbgs() << "Found group:\n Exts:"; for (auto *I : Exts) dbgs() << " " << *I << "\n"; dbgs() << " Ops:"; for (auto *I : Ops) dbgs() << " " << *I << "\n"; dbgs() << " OtherLeafs:"; for (auto *I : OtherLeafs) dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n"; dbgs() << "Truncs:"; for (auto *I : Truncs) dbgs() << " " << *I << "\n"; }); assert(!Truncs.empty() && "Expected some truncs"); // Check types unsigned NumElts = VT->getNumElements(); unsigned BaseElts = VT->getScalarSizeInBits() == 16 ? 8 : (VT->getScalarSizeInBits() == 8 ? 
16 : 0); if (BaseElts == 0 || NumElts % BaseElts != 0) { LLVM_DEBUG(dbgs() << " Type is unsupported\n"); return false; } if (Start->getOperand(0)->getType()->getScalarSizeInBits() != VT->getScalarSizeInBits() * 2) { LLVM_DEBUG(dbgs() << " Type not double sized\n"); return false; } for (Instruction *I : Exts) if (I->getOperand(0)->getType() != VT) { LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n"); return false; } for (Instruction *I : Truncs) if (I->getType() != VT) { LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n"); return false; } // Check that it looks beneficial if (!isProfitableToInterleave(Exts, Truncs)) return false; // Create new shuffles around the extends / truncs / other leaves. IRBuilder<> Builder(Start); SmallVector LeafMask; SmallVector TruncMask; // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7 8, 10, 12, 14, 9, 11, 13, 15 // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7 8, 12, 9, 13, 10, 14, 11, 15 for (unsigned Base = 0; Base < NumElts; Base += BaseElts) { for (unsigned i = 0; i < BaseElts / 2; i++) LeafMask.push_back(Base + i * 2); for (unsigned i = 0; i < BaseElts / 2; i++) LeafMask.push_back(Base + i * 2 + 1); } for (unsigned Base = 0; Base < NumElts; Base += BaseElts) { for (unsigned i = 0; i < BaseElts / 2; i++) { TruncMask.push_back(Base + i); TruncMask.push_back(Base + i + BaseElts / 2); } } for (Instruction *I : Exts) { LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n"); Builder.SetInsertPoint(I); Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask); bool FPext = isa(I); bool Sext = isa(I); Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType()) : Sext ? 
Builder.CreateSExt(Shuffle, I->getType()) : Builder.CreateZExt(Shuffle, I->getType()); I->replaceAllUsesWith(Ext); LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n"); } for (Use *I : OtherLeafs) { LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n"); Builder.SetInsertPoint(cast(I->getUser())); Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask); I->getUser()->setOperand(I->getOperandNo(), Shuffle); LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n"); } for (Instruction *I : Truncs) { LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n"); Builder.SetInsertPoint(I->getParent(), ++I->getIterator()); Value *Shuf = Builder.CreateShuffleVector(I, TruncMask); I->replaceAllUsesWith(Shuf); cast(Shuf)->setOperand(0, I); LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n"); } return true; } bool MVELaneInterleaving::runOnFunction(Function &F) { if (!EnableInterleave) return false; auto &TPC = getAnalysis(); auto &TM = TPC.getTM(); auto *ST = &TM.getSubtarget(F); if (!ST->hasMVEIntegerOps()) return false; bool Changed = false; SmallPtrSet Visited; for (Instruction &I : reverse(instructions(F))) { if (I.getType()->isVectorTy() && (isa(I) || isa(I)) && !Visited.count(&I)) Changed |= tryInterleave(&I, Visited); } return Changed; }