xref: /freebsd/contrib/llvm-project/llvm/lib/Target/Hexagon/HexagonGenExtract.cpp (revision 924226fba12cc9a228c73b956e1b7fa24c60b055)
1 //===- HexagonGenExtract.cpp ----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ADT/APInt.h"
10 #include "llvm/ADT/GraphTraits.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/CFG.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/Function.h"
16 #include "llvm/IR/IRBuilder.h"
17 #include "llvm/IR/Instruction.h"
18 #include "llvm/IR/Instructions.h"
19 #include "llvm/IR/Intrinsics.h"
20 #include "llvm/IR/IntrinsicsHexagon.h"
21 #include "llvm/IR/PatternMatch.h"
22 #include "llvm/IR/Type.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/CommandLine.h"
27 #include <algorithm>
28 #include <cstdint>
29 #include <iterator>
30 
31 using namespace llvm;
32 
// Debugging aid: upper bound on the number of "extract" conversions performed
// by this pass. visitBlock only honors the limit when the option was actually
// given on the command line (it tests getPosition()); otherwise the ~0U
// default is never compared against the running count.
static cl::opt<unsigned> ExtractCutoff("extract-cutoff", cl::init(~0U),
  cl::Hidden, cl::desc("Cutoff for generating \"extract\""
  " instructions"));
36 
// This prevents generating extract instructions that have the offset of 0.
// One of the reasons for "extract" is to put a sequence of bits in a
// register, starting at offset 0 (so that these bits can then be used by an
// "insert"). If the bits are already at offset 0, it is better not to
// generate "extract", since logical bit operations can be merged into
// compound instructions (as opposed to "extract").
//
// In convert() this flag is consulted only for the pattern that has no shift
// right at all, i.e. (and (shl x, #sl), #m): when the flag is set, a match of
// that pattern is rejected.
static cl::opt<bool> NoSR0("extract-nosr0", cl::init(true), cl::Hidden,
  cl::desc("No extract instruction with offset 0"));
45 
// Option intended to require an "and" in the matched patterns.
// NOTE(review): this option is not referenced anywhere else in this file as
// shown; convert() accepts the (shl (lshr/ashr x, #sr), #sl) patterns, which
// have no "and", regardless of its value. Confirm whether it was meant to
// gate those patterns or is simply a leftover.
static cl::opt<bool> NeedAnd("extract-needand", cl::init(true), cl::Hidden,
  cl::desc("Require & in extract patterns"));
48 
// Forward declarations of the entry points this file provides in namespace
// llvm: the pass-initialization hook (called from the pass constructor) and
// the factory function defined at the end of the file.
namespace llvm {

void initializeHexagonGenExtractPass(PassRegistry&);
FunctionPass *createHexagonGenExtract();

} // end namespace llvm
55 
56 namespace {
57 
58   class HexagonGenExtract : public FunctionPass {
59   public:
60     static char ID;
61 
62     HexagonGenExtract() : FunctionPass(ID) {
63       initializeHexagonGenExtractPass(*PassRegistry::getPassRegistry());
64     }
65 
66     StringRef getPassName() const override {
67       return "Hexagon generate \"extract\" instructions";
68     }
69 
70     bool runOnFunction(Function &F) override;
71 
72     void getAnalysisUsage(AnalysisUsage &AU) const override {
73       AU.addRequired<DominatorTreeWrapperPass>();
74       AU.addPreserved<DominatorTreeWrapperPass>();
75       FunctionPass::getAnalysisUsage(AU);
76     }
77 
78   private:
79     bool visitBlock(BasicBlock *B);
80     bool convert(Instruction *In);
81 
82     unsigned ExtractCount = 0;
83     DominatorTree *DT;
84   };
85 
86 } // end anonymous namespace
87 
// Storage for the pass identifier that the constructor hands to
// FunctionPass(ID).
char HexagonGenExtract::ID = 0;

// Pass registration. The pass is registered under the command-line name
// "hextract", and DominatorTreeWrapperPass is listed as a dependency (it is
// also what getAnalysisUsage marks as required).
INITIALIZE_PASS_BEGIN(HexagonGenExtract, "hextract", "Hexagon generate "
  "\"extract\" instructions", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(HexagonGenExtract, "hextract", "Hexagon generate "
  "\"extract\" instructions", false, false)
95 
96 bool HexagonGenExtract::convert(Instruction *In) {
97   using namespace PatternMatch;
98 
99   Value *BF = nullptr;
100   ConstantInt *CSL = nullptr, *CSR = nullptr, *CM = nullptr;
101   BasicBlock *BB = In->getParent();
102   LLVMContext &Ctx = BB->getContext();
103   bool LogicalSR;
104 
105   // (and (shl (lshr x, #sr), #sl), #m)
106   LogicalSR = true;
107   bool Match = match(In, m_And(m_Shl(m_LShr(m_Value(BF), m_ConstantInt(CSR)),
108                                m_ConstantInt(CSL)),
109                          m_ConstantInt(CM)));
110 
111   if (!Match) {
112     // (and (shl (ashr x, #sr), #sl), #m)
113     LogicalSR = false;
114     Match = match(In, m_And(m_Shl(m_AShr(m_Value(BF), m_ConstantInt(CSR)),
115                             m_ConstantInt(CSL)),
116                       m_ConstantInt(CM)));
117   }
118   if (!Match) {
119     // (and (shl x, #sl), #m)
120     LogicalSR = true;
121     CSR = ConstantInt::get(Type::getInt32Ty(Ctx), 0);
122     Match = match(In, m_And(m_Shl(m_Value(BF), m_ConstantInt(CSL)),
123                       m_ConstantInt(CM)));
124     if (Match && NoSR0)
125       return false;
126   }
127   if (!Match) {
128     // (and (lshr x, #sr), #m)
129     LogicalSR = true;
130     CSL = ConstantInt::get(Type::getInt32Ty(Ctx), 0);
131     Match = match(In, m_And(m_LShr(m_Value(BF), m_ConstantInt(CSR)),
132                             m_ConstantInt(CM)));
133   }
134   if (!Match) {
135     // (and (ashr x, #sr), #m)
136     LogicalSR = false;
137     CSL = ConstantInt::get(Type::getInt32Ty(Ctx), 0);
138     Match = match(In, m_And(m_AShr(m_Value(BF), m_ConstantInt(CSR)),
139                             m_ConstantInt(CM)));
140   }
141   if (!Match) {
142     CM = nullptr;
143     // (shl (lshr x, #sr), #sl)
144     LogicalSR = true;
145     Match = match(In, m_Shl(m_LShr(m_Value(BF), m_ConstantInt(CSR)),
146                             m_ConstantInt(CSL)));
147   }
148   if (!Match) {
149     CM = nullptr;
150     // (shl (ashr x, #sr), #sl)
151     LogicalSR = false;
152     Match = match(In, m_Shl(m_AShr(m_Value(BF), m_ConstantInt(CSR)),
153                             m_ConstantInt(CSL)));
154   }
155   if (!Match)
156     return false;
157 
158   Type *Ty = BF->getType();
159   if (!Ty->isIntegerTy())
160     return false;
161   unsigned BW = Ty->getPrimitiveSizeInBits();
162   if (BW != 32 && BW != 64)
163     return false;
164 
165   uint32_t SR = CSR->getZExtValue();
166   uint32_t SL = CSL->getZExtValue();
167 
168   if (!CM) {
169     // If there was no and, and the shift left did not remove all potential
170     // sign bits created by the shift right, then extractu cannot reproduce
171     // this value.
172     if (!LogicalSR && (SR > SL))
173       return false;
174     APInt A = APInt(BW, ~0ULL).lshr(SR).shl(SL);
175     CM = ConstantInt::get(Ctx, A);
176   }
177 
178   // CM is the shifted-left mask. Shift it back right to remove the zero
179   // bits on least-significant positions.
180   APInt M = CM->getValue().lshr(SL);
181   uint32_t T = M.countTrailingOnes();
182 
183   // During the shifts some of the bits will be lost. Calculate how many
184   // of the original value will remain after shift right and then left.
185   uint32_t U = BW - std::max(SL, SR);
186   // The width of the extracted field is the minimum of the original bits
187   // that remain after the shifts and the number of contiguous 1s in the mask.
188   uint32_t W = std::min(U, T);
189   if (W == 0 || W == 1)
190     return false;
191 
192   // Check if the extracted bits are contained within the mask that it is
193   // and-ed with. The extract operation will copy these bits, and so the
194   // mask cannot any holes in it that would clear any of the bits of the
195   // extracted field.
196   if (!LogicalSR) {
197     // If the shift right was arithmetic, it could have included some 1 bits.
198     // It is still ok to generate extract, but only if the mask eliminates
199     // those bits (i.e. M does not have any bits set beyond U).
200     APInt C = APInt::getHighBitsSet(BW, BW-U);
201     if (M.intersects(C) || !M.isMask(W))
202       return false;
203   } else {
204     // Check if M starts with a contiguous sequence of W times 1 bits. Get
205     // the low U bits of M (which eliminates the 0 bits shifted in on the
206     // left), and check if the result is APInt's "mask":
207     if (!M.getLoBits(U).isMask(W))
208       return false;
209   }
210 
211   IRBuilder<> IRB(In);
212   Intrinsic::ID IntId = (BW == 32) ? Intrinsic::hexagon_S2_extractu
213                                    : Intrinsic::hexagon_S2_extractup;
214   Module *Mod = BB->getParent()->getParent();
215   Function *ExtF = Intrinsic::getDeclaration(Mod, IntId);
216   Value *NewIn = IRB.CreateCall(ExtF, {BF, IRB.getInt32(W), IRB.getInt32(SR)});
217   if (SL != 0)
218     NewIn = IRB.CreateShl(NewIn, SL, CSL->getName());
219   In->replaceAllUsesWith(NewIn);
220   return true;
221 }
222 
223 bool HexagonGenExtract::visitBlock(BasicBlock *B) {
224   bool Changed = false;
225 
226   // Depth-first, bottom-up traversal.
227   for (auto *DTN : children<DomTreeNode*>(DT->getNode(B)))
228     Changed |= visitBlock(DTN->getBlock());
229 
230   // Allow limiting the number of generated extracts for debugging purposes.
231   bool HasCutoff = ExtractCutoff.getPosition();
232   unsigned Cutoff = ExtractCutoff;
233 
234   BasicBlock::iterator I = std::prev(B->end()), NextI, Begin = B->begin();
235   while (true) {
236     if (HasCutoff && (ExtractCount >= Cutoff))
237       return Changed;
238     bool Last = (I == Begin);
239     if (!Last)
240       NextI = std::prev(I);
241     Instruction *In = &*I;
242     bool Done = convert(In);
243     if (HasCutoff && Done)
244       ExtractCount++;
245     Changed |= Done;
246     if (Last)
247       break;
248     I = NextI;
249   }
250   return Changed;
251 }
252 
253 bool HexagonGenExtract::runOnFunction(Function &F) {
254   if (skipFunction(F))
255     return false;
256 
257   DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
258   bool Changed;
259 
260   // Traverse the function bottom-up, to see super-expressions before their
261   // sub-expressions.
262   BasicBlock *Entry = GraphTraits<Function*>::getEntryNode(&F);
263   Changed = visitBlock(Entry);
264 
265   return Changed;
266 }
267 
// Factory for the pass: returns a newly heap-allocated HexagonGenExtract.
// The object is created with a plain `new`, so the caller is responsible for
// its lifetime (presumably it is handed to a pass manager -- confirm at the
// call site).
FunctionPass *llvm::createHexagonGenExtract() {
  return new HexagonGenExtract();
}
271