1 //===- MatrixUtils.cpp - Utilities to lower matrix intrinsics ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Utilities for generating tiled loops for matrix operations. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "llvm/Transforms/Utils/MatrixUtils.h" 14 #include "llvm/Analysis/DomTreeUpdater.h" 15 #include "llvm/Analysis/LoopInfo.h" 16 #include "llvm/IR/BasicBlock.h" 17 #include "llvm/IR/Dominators.h" 18 #include "llvm/IR/IRBuilder.h" 19 #include "llvm/IR/Type.h" 20 21 using namespace llvm; 22 23 BasicBlock *TileInfo::CreateLoop(BasicBlock *Preheader, BasicBlock *Exit, 24 Value *Bound, Value *Step, StringRef Name, 25 IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L, 26 LoopInfo &LI) { 27 LLVMContext &Ctx = Preheader->getContext(); 28 BasicBlock *Header = BasicBlock::Create( 29 Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit); 30 BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body", 31 Header->getParent(), Exit); 32 BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch", 33 Header->getParent(), Exit); 34 35 Type *I32Ty = Type::getInt64Ty(Ctx); 36 BranchInst::Create(Body, Header); 37 BranchInst::Create(Latch, Body); 38 PHINode *IV = 39 PHINode::Create(I32Ty, 2, Name + ".iv", Header->getTerminator()->getIterator()); 40 IV->addIncoming(ConstantInt::get(I32Ty, 0), Preheader); 41 42 B.SetInsertPoint(Latch); 43 Value *Inc = B.CreateAdd(IV, Step, Name + ".step"); 44 Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond"); 45 BranchInst::Create(Header, Exit, Cond, Latch); 46 IV->addIncoming(Inc, Latch); 47 48 BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator()); 49 BasicBlock *Tmp = PreheaderBr->getSuccessor(0); 50 PreheaderBr->setSuccessor(0, Header); 51 DTU.applyUpdatesPermissive({ 52 {DominatorTree::Delete, Preheader, Tmp}, 53 {DominatorTree::Insert, Header, Body}, 54 {DominatorTree::Insert, Body, Latch}, 55 {DominatorTree::Insert, Latch, Header}, 56 {DominatorTree::Insert, Latch, Exit}, 57 {DominatorTree::Insert, Preheader, Header}, 58 }); 59 60 L->addBasicBlockToLoop(Header, LI); 61 L->addBasicBlockToLoop(Body, LI); 62 L->addBasicBlockToLoop(Latch, LI); 63 return Body; 64 } 65 66 // Creates the following loop nest skeleton: 67 // for C = 0; C < NumColumns; C += TileSize 68 // for R = 0; R < NumRows; R += TileSize 69 // for K = 0; K < Inner ; K += TileSize 70 BasicBlock *TileInfo::CreateTiledLoops(BasicBlock *Start, BasicBlock *End, 71 IRBuilderBase &B, DomTreeUpdater &DTU, 72 LoopInfo &LI) { 73 Loop *ColumnLoopInfo = LI.AllocateLoop(); 74 Loop *RowLoopInfo = LI.AllocateLoop(); 75 Loop *KLoopInfo = LI.AllocateLoop(); 76 RowLoopInfo->addChildLoop(KLoopInfo); 77 ColumnLoopInfo->addChildLoop(RowLoopInfo); 78 if (Loop *ParentL = LI.getLoopFor(Start)) 79 ParentL->addChildLoop(ColumnLoopInfo); 80 else 81 LI.addTopLevelLoop(ColumnLoopInfo); 82 83 BasicBlock *ColBody = 84 CreateLoop(Start, End, B.getInt64(NumColumns), B.getInt64(TileSize), 85 "cols", B, DTU, ColumnLoopInfo, LI); 86 ColumnLoop.Latch = ColBody->getSingleSuccessor(); 87 BasicBlock *RowBody = 88 CreateLoop(ColBody, ColumnLoop.Latch, B.getInt64(NumRows), 89 B.getInt64(TileSize), "rows", B, DTU, RowLoopInfo, LI); 90 RowLoop.Latch = RowBody->getSingleSuccessor(); 91 92 BasicBlock *InnerBody = 93 CreateLoop(RowBody, RowLoop.Latch, B.getInt64(NumInner), 94 B.getInt64(TileSize), "inner", B, DTU, KLoopInfo, LI); 95 KLoop.Latch = InnerBody->getSingleSuccessor(); 96 ColumnLoop.Header = ColBody->getSinglePredecessor(); 97 RowLoop.Header = RowBody->getSinglePredecessor(); 98 KLoop.Header = InnerBody->getSinglePredecessor(); 99 RowLoop.Index = &*RowLoop.Header->begin(); 100 ColumnLoop.Index = &*ColumnLoop.Header->begin(); 101 KLoop.Index = &*KLoop.Header->begin(); 102 103 return InnerBody; 104 } 105