1 //===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 15 #define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 16 17 #include "NVPTX.h" 18 #include "NVPTXSubtarget.h" 19 #include "NVPTXTargetMachine.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/CodeGen/AsmPrinter.h" 24 #include "llvm/CodeGen/MachineFunction.h" 25 #include "llvm/CodeGen/MachineLoopInfo.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DebugLoc.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/GlobalValue.h" 31 #include "llvm/IR/Value.h" 32 #include "llvm/MC/MCExpr.h" 33 #include "llvm/MC/MCStreamer.h" 34 #include "llvm/MC/MCSymbol.h" 35 #include "llvm/Pass.h" 36 #include "llvm/Support/Casting.h" 37 #include "llvm/Support/Compiler.h" 38 #include "llvm/Support/ErrorHandling.h" 39 #include "llvm/Support/raw_ostream.h" 40 #include "llvm/Target/TargetMachine.h" 41 #include <algorithm> 42 #include <cassert> 43 #include <map> 44 #include <memory> 45 #include <string> 46 #include <vector> 47 48 // The ptx syntax and format is very different from that usually seen in a .s 49 // file, 50 // therefore we are not able to use the MCAsmStreamer interface here. 51 // 52 // We are handcrafting the output method here. 
53 // 54 // A better approach is to clone the MCAsmStreamer to a MCPTXAsmStreamer 55 // (subclass of MCStreamer). 56 57 namespace llvm { 58 59 class MCOperand; 60 61 class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { 62 63 class AggBuffer { 64 // Used to buffer the emitted string for initializing global 65 // aggregates. 66 // 67 // Normally an aggregate (array, vector or structure) is emitted 68 // as a u8[]. However, if one element/field of the aggregate 69 // is a non-NULL address, then the aggregate is emitted as u32[] 70 // or u64[]. 71 // 72 // We first layout the aggregate in 'buffer' in bytes, except for 73 // those symbol addresses. For the i-th symbol address in the 74 //aggregate, its corresponding 4-byte or 8-byte elements in 'buffer' 75 // are filled with 0s. symbolPosInBuffer[i-1] records its position 76 // in 'buffer', and Symbols[i-1] records the Value*. 77 // 78 // Once we have this AggBuffer setup, we can choose how to print 79 // it out. 80 public: 81 unsigned numSymbols; // number of symbol addresses 82 83 private: 84 const unsigned size; // size of the buffer in bytes 85 std::vector<unsigned char> buffer; // the buffer 86 SmallVector<unsigned, 4> symbolPosInBuffer; 87 SmallVector<const Value *, 4> Symbols; 88 // SymbolsBeforeStripping[i] is the original form of Symbols[i] before 89 // stripping pointer casts, i.e., 90 // Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts(). 91 // 92 // We need to keep these values because AggBuffer::print decides whether to 93 // emit a "generic()" cast for Symbols[i] depending on the address space of 94 // SymbolsBeforeStripping[i]. 
95 SmallVector<const Value *, 4> SymbolsBeforeStripping; 96 unsigned curpos; 97 raw_ostream &O; 98 NVPTXAsmPrinter &AP; 99 bool EmitGeneric; 100 101 public: 102 AggBuffer(unsigned size, raw_ostream &O, NVPTXAsmPrinter &AP) 103 : size(size), buffer(size), O(O), AP(AP) { 104 curpos = 0; 105 numSymbols = 0; 106 EmitGeneric = AP.EmitGeneric; 107 } 108 109 // Copy Num bytes from Ptr. 110 // if Bytes > Num, zero fill up to Bytes. 111 unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) { 112 assert((curpos + Num) <= size); 113 assert((curpos + Bytes) <= size); 114 for (int i = 0; i < Num; ++i) { 115 buffer[curpos] = Ptr[i]; 116 curpos++; 117 } 118 for (int i = Num; i < Bytes; ++i) { 119 buffer[curpos] = 0; 120 curpos++; 121 } 122 return curpos; 123 } 124 125 unsigned addZeros(int Num) { 126 assert((curpos + Num) <= size); 127 for (int i = 0; i < Num; ++i) { 128 buffer[curpos] = 0; 129 curpos++; 130 } 131 return curpos; 132 } 133 134 void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) { 135 symbolPosInBuffer.push_back(curpos); 136 Symbols.push_back(GVar); 137 SymbolsBeforeStripping.push_back(GVarBeforeStripping); 138 numSymbols++; 139 } 140 141 void print() { 142 if (numSymbols == 0) { 143 // print out in bytes 144 for (unsigned i = 0; i < size; i++) { 145 if (i) 146 O << ", "; 147 O << (unsigned int) buffer[i]; 148 } 149 } else { 150 // print out in 4-bytes or 8-bytes 151 unsigned int pos = 0; 152 unsigned int nSym = 0; 153 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 154 unsigned int nBytes = 4; 155 if (static_cast<const NVPTXTargetMachine &>(AP.TM).is64Bit()) 156 nBytes = 8; 157 for (pos = 0; pos < size; pos += nBytes) { 158 if (pos) 159 O << ", "; 160 if (pos == nextSymbolPos) { 161 const Value *v = Symbols[nSym]; 162 const Value *v0 = SymbolsBeforeStripping[nSym]; 163 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 164 MCSymbol *Name = AP.getSymbol(GVar); 165 PointerType *PTy = dyn_cast<PointerType>(v0->getType()); 166 bool 
IsNonGenericPointer = false; // Is v0 a non-generic pointer? 167 if (PTy && PTy->getAddressSpace() != 0) { 168 IsNonGenericPointer = true; 169 } 170 if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) { 171 O << "generic("; 172 Name->print(O, AP.MAI); 173 O << ")"; 174 } else { 175 Name->print(O, AP.MAI); 176 } 177 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) { 178 const MCExpr *Expr = 179 AP.lowerConstantForGV(cast<Constant>(CExpr), false); 180 AP.printMCExpr(*Expr, O); 181 } else 182 llvm_unreachable("symbol type unknown"); 183 nSym++; 184 if (nSym >= numSymbols) 185 nextSymbolPos = size + 1; 186 else 187 nextSymbolPos = symbolPosInBuffer[nSym]; 188 } else if (nBytes == 4) 189 O << *(unsigned int *)(&buffer[pos]); 190 else 191 O << *(unsigned long long *)(&buffer[pos]); 192 } 193 } 194 } 195 }; 196 197 friend class AggBuffer; 198 199 private: 200 StringRef getPassName() const override { return "NVPTX Assembly Printer"; } 201 202 const Function *F; 203 std::string CurrentFnName; 204 205 void emitStartOfAsmFile(Module &M) override; 206 void emitBasicBlockStart(const MachineBasicBlock &MBB) override; 207 void emitFunctionEntryLabel() override; 208 void emitFunctionBodyStart() override; 209 void emitFunctionBodyEnd() override; 210 void emitImplicitDef(const MachineInstr *MI) const override; 211 212 void emitInstruction(const MachineInstr *) override; 213 void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI); 214 bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp); 215 MCOperand GetSymbolRef(const MCSymbol *Symbol); 216 unsigned encodeVirtualRegister(unsigned Reg); 217 218 void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O, 219 const char *Modifier = nullptr); 220 void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O, 221 bool = false); 222 void printParamName(Function::const_arg_iterator I, int paramIndex, 223 raw_ostream &O); 224 void emitGlobals(const Module &M); 225 void 
emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI); 226 void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const; 227 void emitVirtualRegister(unsigned int vr, raw_ostream &); 228 void emitFunctionParamList(const Function *, raw_ostream &O); 229 void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O); 230 void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF); 231 void printReturnValStr(const Function *, raw_ostream &O); 232 void printReturnValStr(const MachineFunction &MF, raw_ostream &O); 233 bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 234 const char *ExtraCode, raw_ostream &) override; 235 void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O); 236 bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo, 237 const char *ExtraCode, raw_ostream &) override; 238 239 const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric); 240 void printMCExpr(const MCExpr &Expr, raw_ostream &OS); 241 242 protected: 243 bool doInitialization(Module &M) override; 244 bool doFinalization(Module &M) override; 245 246 private: 247 bool GlobalsEmitted; 248 249 // This is specific per MachineFunction. 250 const MachineRegisterInfo *MRI; 251 // The contents are specific for each 252 // MachineFunction. But the size of the 253 // array is not. 254 typedef DenseMap<unsigned, unsigned> VRegMap; 255 typedef DenseMap<const TargetRegisterClass *, VRegMap> VRegRCMap; 256 VRegRCMap VRegMapping; 257 258 // List of variables demoted to a function scope. 
259 std::map<const Function *, std::vector<const GlobalVariable *>> localDecls; 260 261 void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O); 262 void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const; 263 std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const; 264 void printScalarConstant(const Constant *CPV, raw_ostream &O); 265 void printFPConstant(const ConstantFP *Fp, raw_ostream &O); 266 void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer); 267 void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer); 268 269 void emitLinkageDirective(const GlobalValue *V, raw_ostream &O); 270 void emitDeclarations(const Module &, raw_ostream &O); 271 void emitDeclaration(const Function *, raw_ostream &O); 272 void emitDemotedVars(const Function *, raw_ostream &); 273 274 bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo, 275 MCOperand &MCOp); 276 void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp); 277 278 bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const; 279 280 // Used to control the need to emit .generic() in the initializer of 281 // module scope variables. 282 // Although ptx supports the hybrid mode like the following, 283 // .global .u32 a; 284 // .global .u32 b; 285 // .global .u32 addr[] = {a, generic(b)} 286 // we have difficulty representing the difference in the NVVM IR. 287 // 288 // Since the address value should always be generic in CUDA C and always 289 // be specific in OpenCL, we use this simple control here. 
290 // 291 bool EmitGeneric; 292 293 public: 294 NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer) 295 : AsmPrinter(TM, std::move(Streamer)), 296 EmitGeneric(static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == 297 NVPTX::CUDA) {} 298 299 bool runOnMachineFunction(MachineFunction &F) override; 300 301 void getAnalysisUsage(AnalysisUsage &AU) const override { 302 AU.addRequired<MachineLoopInfo>(); 303 AsmPrinter::getAnalysisUsage(AU); 304 } 305 306 std::string getVirtualRegisterName(unsigned) const; 307 308 const MCSymbol *getFunctionFrameSymbol() const override; 309 310 // Make emitGlobalVariable() no-op for NVPTX. 311 // Global variables have been already emitted by the time the base AsmPrinter 312 // attempts to do so in doFinalization() (see NVPTXAsmPrinter::emitGlobals()). 313 void emitGlobalVariable(const GlobalVariable *GV) override {} 314 }; 315 316 } // end namespace llvm 317 318 #endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 319