xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp (revision f9fd7337f63698f33239c58c07bf430198235a22)
1 //===----------------------- AlignmentFromAssumptions.cpp -----------------===//
2 //                  Set Load/Store Alignments From Assumptions
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a ScalarEvolution-based transformation to set
11 // the alignments of load, stores and memory intrinsics based on the truth
12 // expressions of assume intrinsics. The primary motivation is to handle
13 // complex alignment assumptions that apply to vector loads and stores that
14 // appear after vectorization and unrolling.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "llvm/InitializePasses.h"
19 #define AA_NAME "alignment-from-assumptions"
20 #define DEBUG_TYPE AA_NAME
21 #include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/Analysis/AliasAnalysis.h"
25 #include "llvm/Analysis/AssumptionCache.h"
26 #include "llvm/Analysis/GlobalsModRef.h"
27 #include "llvm/Analysis/LoopInfo.h"
28 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
29 #include "llvm/Analysis/ValueTracking.h"
30 #include "llvm/IR/Constant.h"
31 #include "llvm/IR/Dominators.h"
32 #include "llvm/IR/Instruction.h"
33 #include "llvm/IR/IntrinsicInst.h"
34 #include "llvm/IR/Intrinsics.h"
35 #include "llvm/IR/Module.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/Transforms/Scalar.h"
39 using namespace llvm;
40 
41 STATISTIC(NumLoadAlignChanged,
42   "Number of loads changed by alignment assumptions");
43 STATISTIC(NumStoreAlignChanged,
44   "Number of stores changed by alignment assumptions");
45 STATISTIC(NumMemIntAlignChanged,
46   "Number of memory intrinsics changed by alignment assumptions");
47 
48 namespace {
49 struct AlignmentFromAssumptions : public FunctionPass {
50   static char ID; // Pass identification, replacement for typeid
51   AlignmentFromAssumptions() : FunctionPass(ID) {
52     initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry());
53   }
54 
55   bool runOnFunction(Function &F) override;
56 
57   void getAnalysisUsage(AnalysisUsage &AU) const override {
58     AU.addRequired<AssumptionCacheTracker>();
59     AU.addRequired<ScalarEvolutionWrapperPass>();
60     AU.addRequired<DominatorTreeWrapperPass>();
61 
62     AU.setPreservesCFG();
63     AU.addPreserved<AAResultsWrapperPass>();
64     AU.addPreserved<GlobalsAAWrapperPass>();
65     AU.addPreserved<LoopInfoWrapperPass>();
66     AU.addPreserved<DominatorTreeWrapperPass>();
67     AU.addPreserved<ScalarEvolutionWrapperPass>();
68   }
69 
70   AlignmentFromAssumptionsPass Impl;
71 };
72 }
73 
74 char AlignmentFromAssumptions::ID = 0;
75 static const char aip_name[] = "Alignment from assumptions";
76 INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME,
77                       aip_name, false, false)
78 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
79 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
80 INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
81 INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME,
82                     aip_name, false, false)
83 
84 FunctionPass *llvm::createAlignmentFromAssumptionsPass() {
85   return new AlignmentFromAssumptions();
86 }
87 
88 // Given an expression for the (constant) alignment, AlignSCEV, and an
89 // expression for the displacement between a pointer and the aligned address,
90 // DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
91 // to a constant. Using SCEV to compute alignment handles the case where
92 // DiffSCEV is a recurrence with constant start such that the aligned offset
93 // is constant. e.g. {16,+,32} % 32 -> 16.
94 static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
95                                       const SCEV *AlignSCEV,
96                                       ScalarEvolution *SE) {
97   // DiffUnits = Diff % int64_t(Alignment)
98   const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);
99 
100   LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
101                     << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");
102 
103   if (const SCEVConstant *ConstDUSCEV =
104       dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
105     int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();
106 
107     // If the displacement is an exact multiple of the alignment, then the
108     // displaced pointer has the same alignment as the aligned pointer, so
109     // return the alignment value.
110     if (!DiffUnits)
111       return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();
112 
113     // If the displacement is not an exact multiple, but the remainder is a
114     // constant, then return this remainder (but only if it is a power of 2).
115     uint64_t DiffUnitsAbs = std::abs(DiffUnits);
116     if (isPowerOf2_64(DiffUnitsAbs))
117       return Align(DiffUnitsAbs);
118   }
119 
120   return None;
121 }
122 
123 // There is an address given by an offset OffSCEV from AASCEV which has an
124 // alignment AlignSCEV. Use that information, if possible, to compute a new
125 // alignment for Ptr.
126 static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
127                              const SCEV *OffSCEV, Value *Ptr,
128                              ScalarEvolution *SE) {
129   const SCEV *PtrSCEV = SE->getSCEV(Ptr);
130   // On a platform with 32-bit allocas, but 64-bit flat/global pointer sizes
131   // (*cough* AMDGPU), the effective SCEV type of AASCEV and PtrSCEV
132   // may disagree. Trunc/extend so they agree.
133   PtrSCEV = SE->getTruncateOrZeroExtend(
134       PtrSCEV, SE->getEffectiveSCEVType(AASCEV->getType()));
135   const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
136 
137   // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
138   // sign-extended OffSCEV to i64, so make sure they agree again.
139   DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());
140 
141   // What we really want to know is the overall offset to the aligned
142   // address. This address is displaced by the provided offset.
143   DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV);
144 
145   LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
146                     << *AlignSCEV << " and offset " << *OffSCEV
147                     << " using diff " << *DiffSCEV << "\n");
148 
149   if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
150     LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
151     return *NewAlignment;
152   }
153 
154   if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
155     // The relative offset to the alignment assumption did not yield a constant,
156     // but we should try harder: if we assume that a is 32-byte aligned, then in
157     // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
158     // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
159     // As a result, the new alignment will not be a constant, but can still
160     // be improved over the default (of 4) to 16.
161 
162     const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
163     const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);
164 
165     LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
166                       << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");
167 
168     // Now compute the new alignment using the displacement to the value in the
169     // first iteration, and also the alignment using the per-iteration delta.
170     // If these are the same, then use that answer. Otherwise, use the smaller
171     // one, but only if it divides the larger one.
172     MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
173     MaybeAlign NewIncAlignment =
174         getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);
175 
176     LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
177                       << "\n");
178     LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
179                       << "\n");
180 
181     if (!NewAlignment || !NewIncAlignment)
182       return Align(1);
183 
184     const Align NewAlign = *NewAlignment;
185     const Align NewIncAlign = *NewIncAlignment;
186     if (NewAlign > NewIncAlign) {
187       LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
188                         << DebugStr(NewIncAlign) << "\n");
189       return NewIncAlign;
190     }
191     if (NewIncAlign > NewAlign) {
192       LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
193                         << "\n");
194       return NewAlign;
195     }
196     assert(NewIncAlign == NewAlign);
197     LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
198                       << "\n");
199     return NewAlign;
200   }
201 
202   return Align(1);
203 }
204 
205 bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
206                                                         Value *&AAPtr,
207                                                         const SCEV *&AlignSCEV,
208                                                         const SCEV *&OffSCEV) {
209   // An alignment assume must be a statement about the least-significant
210   // bits of the pointer being zero, possibly with some offset.
211   ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0));
212   if (!ICI)
213     return false;
214 
215   // This must be an expression of the form: x & m == 0.
216   if (ICI->getPredicate() != ICmpInst::ICMP_EQ)
217     return false;
218 
219   // Swap things around so that the RHS is 0.
220   Value *CmpLHS = ICI->getOperand(0);
221   Value *CmpRHS = ICI->getOperand(1);
222   const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS);
223   const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS);
224   if (CmpLHSSCEV->isZero())
225     std::swap(CmpLHS, CmpRHS);
226   else if (!CmpRHSSCEV->isZero())
227     return false;
228 
229   BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS);
230   if (!CmpBO || CmpBO->getOpcode() != Instruction::And)
231     return false;
232 
233   // Swap things around so that the right operand of the and is a constant
234   // (the mask); we cannot deal with variable masks.
235   Value *AndLHS = CmpBO->getOperand(0);
236   Value *AndRHS = CmpBO->getOperand(1);
237   const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS);
238   const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS);
239   if (isa<SCEVConstant>(AndLHSSCEV)) {
240     std::swap(AndLHS, AndRHS);
241     std::swap(AndLHSSCEV, AndRHSSCEV);
242   }
243 
244   const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV);
245   if (!MaskSCEV)
246     return false;
247 
248   // The mask must have some trailing ones (otherwise the condition is
249   // trivial and tells us nothing about the alignment of the left operand).
250   unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes();
251   if (!TrailingOnes)
252     return false;
253 
254   // Cap the alignment at the maximum with which LLVM can deal (and make sure
255   // we don't overflow the shift).
256   uint64_t Alignment;
257   TrailingOnes = std::min(TrailingOnes,
258     unsigned(sizeof(unsigned) * CHAR_BIT - 1));
259   Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment);
260 
261   Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext());
262   AlignSCEV = SE->getConstant(Int64Ty, Alignment);
263 
264   // The LHS might be a ptrtoint instruction, or it might be the pointer
265   // with an offset.
266   AAPtr = nullptr;
267   OffSCEV = nullptr;
268   if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) {
269     AAPtr = PToI->getPointerOperand();
270     OffSCEV = SE->getZero(Int64Ty);
271   } else if (const SCEVAddExpr* AndLHSAddSCEV =
272              dyn_cast<SCEVAddExpr>(AndLHSSCEV)) {
273     // Try to find the ptrtoint; subtract it and the rest is the offset.
274     for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(),
275          JE = AndLHSAddSCEV->op_end(); J != JE; ++J)
276       if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J))
277         if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) {
278           AAPtr = PToI->getPointerOperand();
279           OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J);
280           break;
281         }
282   }
283 
284   if (!AAPtr)
285     return false;
286 
287   // Sign extend the offset to 64 bits (so that it is like all of the other
288   // expressions).
289   unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits();
290   if (OffSCEVBits < 64)
291     OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty);
292   else if (OffSCEVBits > 64)
293     return false;
294 
295   AAPtr = AAPtr->stripPointerCasts();
296   return true;
297 }
298 
299 bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) {
300   Value *AAPtr;
301   const SCEV *AlignSCEV, *OffSCEV;
302   if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV))
303     return false;
304 
305   // Skip ConstantPointerNull and UndefValue.  Assumptions on these shouldn't
306   // affect other users.
307   if (isa<ConstantData>(AAPtr))
308     return false;
309 
310   const SCEV *AASCEV = SE->getSCEV(AAPtr);
311 
312   // Apply the assumption to all other users of the specified pointer.
313   SmallPtrSet<Instruction *, 32> Visited;
314   SmallVector<Instruction*, 16> WorkList;
315   for (User *J : AAPtr->users()) {
316     if (J == ACall)
317       continue;
318 
319     if (Instruction *K = dyn_cast<Instruction>(J))
320       if (isValidAssumeForContext(ACall, K, DT))
321         WorkList.push_back(K);
322   }
323 
324   while (!WorkList.empty()) {
325     Instruction *J = WorkList.pop_back_val();
326     if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
327       Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
328                                            LI->getPointerOperand(), SE);
329       if (NewAlignment > LI->getAlign()) {
330         LI->setAlignment(NewAlignment);
331         ++NumLoadAlignChanged;
332       }
333     } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
334       Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
335                                            SI->getPointerOperand(), SE);
336       if (NewAlignment > SI->getAlign()) {
337         SI->setAlignment(NewAlignment);
338         ++NumStoreAlignChanged;
339       }
340     } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
341       Align NewDestAlignment =
342           getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);
343 
344       LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
345                         << "\n";);
346       if (NewDestAlignment > *MI->getDestAlign()) {
347         MI->setDestAlignment(NewDestAlignment);
348         ++NumMemIntAlignChanged;
349       }
350 
351       // For memory transfers, there is also a source alignment that
352       // can be set.
353       if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
354         Align NewSrcAlignment =
355             getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);
356 
357         LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
358                           << "\n";);
359 
360         if (NewSrcAlignment > *MTI->getSourceAlign()) {
361           MTI->setSourceAlignment(NewSrcAlignment);
362           ++NumMemIntAlignChanged;
363         }
364       }
365     }
366 
367     // Now that we've updated that use of the pointer, look for other uses of
368     // the pointer to update.
369     Visited.insert(J);
370     for (User *UJ : J->users()) {
371       Instruction *K = cast<Instruction>(UJ);
372       if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT))
373         WorkList.push_back(K);
374     }
375   }
376 
377   return true;
378 }
379 
380 bool AlignmentFromAssumptions::runOnFunction(Function &F) {
381   if (skipFunction(F))
382     return false;
383 
384   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
385   ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
386   DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
387 
388   return Impl.runImpl(F, AC, SE, DT);
389 }
390 
391 bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
392                                            ScalarEvolution *SE_,
393                                            DominatorTree *DT_) {
394   SE = SE_;
395   DT = DT_;
396 
397   bool Changed = false;
398   for (auto &AssumeVH : AC.assumptions())
399     if (AssumeVH)
400       Changed |= processAssumption(cast<CallInst>(AssumeVH));
401 
402   return Changed;
403 }
404 
405 PreservedAnalyses
406 AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {
407 
408   AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
409   ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
410   DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
411   if (!runImpl(F, AC, &SE, &DT))
412     return PreservedAnalyses::all();
413 
414   PreservedAnalyses PA;
415   PA.preserveSet<CFGAnalyses>();
416   PA.preserve<AAManager>();
417   PA.preserve<ScalarEvolutionAnalysis>();
418   PA.preserve<GlobalsAA>();
419   return PA;
420 }
421