xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PTX supports 2 methods of accessing device function parameters:
10 //
11 //   - "simple" case: If a parameter is only loaded, and all loads can address
12 //     the parameter via a constant offset, then the parameter may be loaded via
13 //     the ".param" address space. This case is not possible if the parameter
14 //     is stored to or has its address taken. This method is preferable when
15 //     possible. Ex:
16 //
17 //            ld.param.u32    %r1, [foo_param_1];
18 //            ld.param.u32    %r2, [foo_param_1+4];
19 //
20 //   - "move param" case: For more complex cases the address of the param may be
21 //     placed in a register via a "mov" instruction. This "mov" also implicitly
22 //     moves the param to the ".local" address space and allows for it to be
23 //     written to. This essentially defers the responsibility of the byval copy
24 //     to the PTX calling convention.
25 //
26 //            mov.b64         %rd1, foo_param_0;
27 //            st.local.u32    [%rd1], 42;
28 //            add.u64         %rd3, %rd1, %rd2;
29 //            ld.local.u32    %r2, [%rd3];
30 //
31 // In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32 // parameters will use the "move param" case and the local address space. This
33 // pass is responsible for switching to the "simple" case when possible, as it
34 // is more efficient.
35 //
36 // We do this by simply traversing uses of the param "mov" instructions and
37 // trivially checking that they are all loads (looking through local cvtas).
38 //
39 //===----------------------------------------------------------------------===//
40 
41 #include "NVPTX.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/CodeGen/MachineFunctionPass.h"
44 #include "llvm/CodeGen/MachineInstr.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/Support/ErrorHandling.h"
49 
50 using namespace llvm;
51 
traverseMoveUse(MachineInstr & U,const MachineRegisterInfo & MRI,SmallVectorImpl<MachineInstr * > & RemoveList,SmallVectorImpl<MachineInstr * > & LoadInsts)52 static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
53                             SmallVectorImpl<MachineInstr *> &RemoveList,
54                             SmallVectorImpl<MachineInstr *> &LoadInsts) {
55   switch (U.getOpcode()) {
56   case NVPTX::LD_i16:
57   case NVPTX::LD_i32:
58   case NVPTX::LD_i64:
59   case NVPTX::LD_i8:
60   case NVPTX::LDV_i16_v2:
61   case NVPTX::LDV_i16_v4:
62   case NVPTX::LDV_i32_v2:
63   case NVPTX::LDV_i32_v4:
64   case NVPTX::LDV_i64_v2:
65   case NVPTX::LDV_i64_v4:
66   case NVPTX::LDV_i8_v2:
67   case NVPTX::LDV_i8_v4: {
68     LoadInsts.push_back(&U);
69     return true;
70   }
71   case NVPTX::cvta_local:
72   case NVPTX::cvta_local_64:
73   case NVPTX::cvta_to_local:
74   case NVPTX::cvta_to_local_64: {
75     for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
76       if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
77         return false;
78 
79     RemoveList.push_back(&U);
80     return true;
81   }
82   default:
83     return false;
84   }
85 }
86 
eliminateMove(MachineInstr & Mov,const MachineRegisterInfo & MRI,SmallVectorImpl<MachineInstr * > & RemoveList)87 static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
88                           SmallVectorImpl<MachineInstr *> &RemoveList) {
89   SmallVector<MachineInstr *, 16> MaybeRemoveList;
90   SmallVector<MachineInstr *, 16> LoadInsts;
91 
92   for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
93     if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
94       return false;
95 
96   RemoveList.append(MaybeRemoveList);
97   RemoveList.push_back(&Mov);
98 
99   const MachineOperand *ParamSymbol = Mov.uses().begin();
100   assert(ParamSymbol->isSymbol());
101 
102   constexpr unsigned LDInstBasePtrOpIdx = 5;
103   constexpr unsigned LDInstAddrSpaceOpIdx = 2;
104   for (auto *LI : LoadInsts) {
105     (LI->uses().begin() + LDInstBasePtrOpIdx)
106         ->ChangeToES(ParamSymbol->getSymbolName());
107     (LI->uses().begin() + LDInstAddrSpaceOpIdx)
108         ->ChangeToImmediate(NVPTX::AddressSpace::Param);
109   }
110   return true;
111 }
112 
forwardDeviceParams(MachineFunction & MF)113 static bool forwardDeviceParams(MachineFunction &MF) {
114   const auto &MRI = MF.getRegInfo();
115 
116   bool Changed = false;
117   SmallVector<MachineInstr *, 16> RemoveList;
118   for (auto &MI : make_early_inc_range(*MF.begin()))
119     if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
120         MI.getOpcode() == NVPTX::MOV64_PARAM)
121       Changed |= eliminateMove(MI, MRI, RemoveList);
122 
123   for (auto *MI : RemoveList)
124     MI->eraseFromParent();
125 
126   return Changed;
127 }
128 
129 /// ----------------------------------------------------------------------------
130 ///                       Pass (Manager) Boilerplate
131 /// ----------------------------------------------------------------------------
132 
133 namespace {
134 struct NVPTXForwardParamsPass : public MachineFunctionPass {
135   static char ID;
NVPTXForwardParamsPass__anon4673056e0111::NVPTXForwardParamsPass136   NVPTXForwardParamsPass() : MachineFunctionPass(ID) {}
137 
138   bool runOnMachineFunction(MachineFunction &MF) override;
139 
getAnalysisUsage__anon4673056e0111::NVPTXForwardParamsPass140   void getAnalysisUsage(AnalysisUsage &AU) const override {
141     MachineFunctionPass::getAnalysisUsage(AU);
142   }
143 };
144 } // namespace
145 
146 char NVPTXForwardParamsPass::ID = 0;
147 
148 INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
149                 "NVPTX Forward Params", false, false)
150 
runOnMachineFunction(MachineFunction & MF)151 bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
152   return forwardDeviceParams(MF);
153 }
154 
createNVPTXForwardParamsPass()155 MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
156   return new NVPTXForwardParamsPass();
157 }
158