xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/MatrixUtils.cpp (revision af23369a6deaaeb612ab266eb88b8bb8d560c322)
1 //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utilities for generating tiled loops for matrix operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "llvm/Transforms/Utils/MatrixUtils.h"
14 #include "llvm/Analysis/DomTreeUpdater.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/IR/BasicBlock.h"
17 #include "llvm/IR/Dominators.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/IR/Type.h"
20 
21 using namespace llvm;
22 
23 BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
24                                  Value *Bound, Value *Step, StringRef Name,
25                                  IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
26                                  LoopInfo &LI) {
27   LLVMContext &Ctx = Preheader->getContext();
28   BasicBlock *Header = BasicBlock::Create(
29       Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
30   BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
31                                         Header->getParent(), Exit);
32   BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
33                                          Header->getParent(), Exit);
34 
35   Type *I32Ty = Type::getInt64Ty(Ctx);
36   BranchInst::Create(Body, Header);
37   BranchInst::Create(Latch, Body);
38   PHINode *IV =
39       PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator());
40   IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader);
41 
42   B.SetInsertPoint(Latch);
43   Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
44   Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
45   BranchInst::Create(Header, Exit, Cond, Latch);
46   IV->addIncoming(Inc, Latch);
47 
48   BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
49   BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
50   PreheaderBr->setSuccessor(0, Header);
51   DTU.applyUpdatesPermissive({
52       {DominatorTree::Delete, Preheader, Tmp},
53       {DominatorTree::Insert, Header, Body},
54       {DominatorTree::Insert, Body, Latch},
55       {DominatorTree::Insert, Latch, Header},
56       {DominatorTree::Insert, Latch, Exit},
57       {DominatorTree::Insert, Preheader, Header},
58   });
59 
60   L->addBasicBlockToLoop(Header, LI);
61   L->addBasicBlockToLoop(Body, LI);
62   L->addBasicBlockToLoop(Latch, LI);
63   return Body;
64 }
65 
66 // Creates the following loop nest skeleton:
67 //  for C = 0; C < NumColumns; C += TileSize
68 //    for R = 0; R < NumRows; R += TileSize
69 //      for K = 0; K < Inner ; K += TileSize
70 BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End,
71                                        IRBuilderBase &B, DomTreeUpdater &DTU,
72                                        LoopInfo &LI) {
73   Loop *ColumnLoopInfo = LI.AllocateLoop();
74   Loop *RowLoopInfo = LI.AllocateLoop();
75   Loop *KLoopInfo = LI.AllocateLoop();
76   RowLoopInfo->addChildLoop(KLoopInfo);
77   ColumnLoopInfo->addChildLoop(RowLoopInfo);
78   if (Loop *ParentL = LI.getLoopFor(Start))
79     ParentL->addChildLoop(ColumnLoopInfo);
80   else
81     LI.addTopLevelLoop(ColumnLoopInfo);
82 
83   BasicBlock *ColBody =
84       CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize),
85                  "cols", B, DTU, ColumnLoopInfo, LI);
86   ColumnLoop.Latch = ColBody->getSingleSuccessor();
87   BasicBlock *RowBody =
88       CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows),
89                  B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI);
90   RowLoop.Latch = RowBody->getSingleSuccessor();
91 
92   BasicBlock *InnerBody =
93       CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner),
94                  B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI);
95   KLoop.Latch = InnerBody->getSingleSuccessor();
96   ColumnLoop.Header = ColBody->getSinglePredecessor();
97   RowLoop.Header = RowBody->getSinglePredecessor();
98   KLoop.Header = InnerBody->getSinglePredecessor();
99   RowLoop.Index = &*RowLoop.Header->begin();
100   ColumnLoop.Index = &*ColumnLoop.Header->begin();
101   KLoop.Index = &*KLoop.Header->begin();
102 
103   return InnerBody;
104 }
105