1 //- NVPTXForwardParams.cpp - NVPTX Forward Device Params Removing Local Copy -//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // PTX supports 2 methods of accessing device function parameters:
10 //
// - "simple" case: If a parameter is only loaded, and all loads can address
// the parameter via a constant offset, then the parameter may be loaded via
// the ".param" address space. This case is not possible if the parameter
// is stored to or has its address taken. This method is preferable when
// possible. Ex:
16 //
17 // ld.param.u32 %r1, [foo_param_1];
18 // ld.param.u32 %r2, [foo_param_1+4];
19 //
20 // - "move param" case: For more complex cases the address of the param may be
21 // placed in a register via a "mov" instruction. This "mov" also implicitly
22 // moves the param to the ".local" address space and allows for it to be
// written to. This essentially defers the responsibility of the byval copy
24 // to the PTX calling convention.
25 //
26 // mov.b64 %rd1, foo_param_0;
27 // st.local.u32 [%rd1], 42;
28 // add.u64 %rd3, %rd1, %rd2;
29 // ld.local.u32 %r2, [%rd3];
30 //
31 // In NVPTXLowerArgs and SelectionDAG, we pessimistically assume that all
32 // parameters will use the "move param" case and the local address space. This
33 // pass is responsible for switching to the "simple" case when possible, as it
34 // is more efficient.
35 //
// We do this by simply traversing uses of the param "mov" instructions and
37 // trivially checking if they are all loads.
38 //
39 //===----------------------------------------------------------------------===//
40
41 #include "NVPTX.h"
42 #include "llvm/ADT/SmallVector.h"
43 #include "llvm/CodeGen/MachineFunctionPass.h"
44 #include "llvm/CodeGen/MachineInstr.h"
45 #include "llvm/CodeGen/MachineOperand.h"
46 #include "llvm/CodeGen/MachineRegisterInfo.h"
47 #include "llvm/CodeGen/TargetRegisterInfo.h"
48 #include "llvm/Support/ErrorHandling.h"
49
50 using namespace llvm;
51
traverseMoveUse(MachineInstr & U,const MachineRegisterInfo & MRI,SmallVectorImpl<MachineInstr * > & RemoveList,SmallVectorImpl<MachineInstr * > & LoadInsts)52 static bool traverseMoveUse(MachineInstr &U, const MachineRegisterInfo &MRI,
53 SmallVectorImpl<MachineInstr *> &RemoveList,
54 SmallVectorImpl<MachineInstr *> &LoadInsts) {
55 switch (U.getOpcode()) {
56 case NVPTX::LD_i16:
57 case NVPTX::LD_i32:
58 case NVPTX::LD_i64:
59 case NVPTX::LD_i8:
60 case NVPTX::LDV_i16_v2:
61 case NVPTX::LDV_i16_v4:
62 case NVPTX::LDV_i32_v2:
63 case NVPTX::LDV_i32_v4:
64 case NVPTX::LDV_i64_v2:
65 case NVPTX::LDV_i64_v4:
66 case NVPTX::LDV_i8_v2:
67 case NVPTX::LDV_i8_v4: {
68 LoadInsts.push_back(&U);
69 return true;
70 }
71 case NVPTX::cvta_local:
72 case NVPTX::cvta_local_64:
73 case NVPTX::cvta_to_local:
74 case NVPTX::cvta_to_local_64: {
75 for (auto &U2 : MRI.use_instructions(U.operands_begin()->getReg()))
76 if (!traverseMoveUse(U2, MRI, RemoveList, LoadInsts))
77 return false;
78
79 RemoveList.push_back(&U);
80 return true;
81 }
82 default:
83 return false;
84 }
85 }
86
eliminateMove(MachineInstr & Mov,const MachineRegisterInfo & MRI,SmallVectorImpl<MachineInstr * > & RemoveList)87 static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI,
88 SmallVectorImpl<MachineInstr *> &RemoveList) {
89 SmallVector<MachineInstr *, 16> MaybeRemoveList;
90 SmallVector<MachineInstr *, 16> LoadInsts;
91
92 for (auto &U : MRI.use_instructions(Mov.operands_begin()->getReg()))
93 if (!traverseMoveUse(U, MRI, MaybeRemoveList, LoadInsts))
94 return false;
95
96 RemoveList.append(MaybeRemoveList);
97 RemoveList.push_back(&Mov);
98
99 const MachineOperand *ParamSymbol = Mov.uses().begin();
100 assert(ParamSymbol->isSymbol());
101
102 constexpr unsigned LDInstBasePtrOpIdx = 5;
103 constexpr unsigned LDInstAddrSpaceOpIdx = 2;
104 for (auto *LI : LoadInsts) {
105 (LI->uses().begin() + LDInstBasePtrOpIdx)
106 ->ChangeToES(ParamSymbol->getSymbolName());
107 (LI->uses().begin() + LDInstAddrSpaceOpIdx)
108 ->ChangeToImmediate(NVPTX::AddressSpace::Param);
109 }
110 return true;
111 }
112
forwardDeviceParams(MachineFunction & MF)113 static bool forwardDeviceParams(MachineFunction &MF) {
114 const auto &MRI = MF.getRegInfo();
115
116 bool Changed = false;
117 SmallVector<MachineInstr *, 16> RemoveList;
118 for (auto &MI : make_early_inc_range(*MF.begin()))
119 if (MI.getOpcode() == NVPTX::MOV32_PARAM ||
120 MI.getOpcode() == NVPTX::MOV64_PARAM)
121 Changed |= eliminateMove(MI, MRI, RemoveList);
122
123 for (auto *MI : RemoveList)
124 MI->eraseFromParent();
125
126 return Changed;
127 }
128
129 /// ----------------------------------------------------------------------------
130 /// Pass (Manager) Boilerplate
131 /// ----------------------------------------------------------------------------
132
133 namespace {
134 struct NVPTXForwardParamsPass : public MachineFunctionPass {
135 static char ID;
NVPTXForwardParamsPass__anon4673056e0111::NVPTXForwardParamsPass136 NVPTXForwardParamsPass() : MachineFunctionPass(ID) {}
137
138 bool runOnMachineFunction(MachineFunction &MF) override;
139
getAnalysisUsage__anon4673056e0111::NVPTXForwardParamsPass140 void getAnalysisUsage(AnalysisUsage &AU) const override {
141 MachineFunctionPass::getAnalysisUsage(AU);
142 }
143 };
144 } // namespace
145
146 char NVPTXForwardParamsPass::ID = 0;
147
148 INITIALIZE_PASS(NVPTXForwardParamsPass, "nvptx-forward-params",
149 "NVPTX Forward Params", false, false)
150
runOnMachineFunction(MachineFunction & MF)151 bool NVPTXForwardParamsPass::runOnMachineFunction(MachineFunction &MF) {
152 return forwardDeviceParams(MF);
153 }
154
createNVPTXForwardParamsPass()155 MachineFunctionPass *llvm::createNVPTXForwardParamsPass() {
156 return new NVPTXForwardParamsPass();
157 }
158