1 //===-- NVPTXAsmPrinter.h - NVPTX LLVM assembly writer ----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains a printer that converts from our internal representation 10 // of machine-dependent LLVM code to NVPTX assembly language. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 15 #define LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 16 17 #include "NVPTX.h" 18 #include "NVPTXSubtarget.h" 19 #include "NVPTXTargetMachine.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/StringRef.h" 23 #include "llvm/CodeGen/AsmPrinter.h" 24 #include "llvm/CodeGen/MachineFunction.h" 25 #include "llvm/CodeGen/MachineLoopInfo.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DebugLoc.h" 28 #include "llvm/IR/DerivedTypes.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/GlobalValue.h" 31 #include "llvm/IR/Value.h" 32 #include "llvm/MC/MCExpr.h" 33 #include "llvm/MC/MCStreamer.h" 34 #include "llvm/MC/MCSymbol.h" 35 #include "llvm/Pass.h" 36 #include "llvm/Support/Casting.h" 37 #include "llvm/Support/Compiler.h" 38 #include "llvm/Support/ErrorHandling.h" 39 #include "llvm/Support/raw_ostream.h" 40 #include "llvm/Target/TargetMachine.h" 41 #include <algorithm> 42 #include <cassert> 43 #include <map> 44 #include <memory> 45 #include <string> 46 #include <vector> 47 48 // The ptx syntax and format is very different from that usually seem in a .s 49 // file, 50 // therefore we are not able to use the MCAsmStreamer interface here. 51 // 52 // We are handcrafting the output method here. 53 // 54 // A better approach is to clone the MCAsmStreamer to a MCPTXAsmStreamer 55 // (subclass of MCStreamer). 56 57 namespace llvm { 58 59 class MCOperand; 60 61 class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter { 62 63 class AggBuffer { 64 // Used to buffer the emitted string for initializing global 65 // aggregates. 66 // 67 // Normally an aggregate (array, vector or structure) is emitted 68 // as a u8[]. However, if one element/field of the aggregate 69 // is a non-NULL address, then the aggregate is emitted as u32[] 70 // or u64[]. 71 // 72 // We first layout the aggregate in 'buffer' in bytes, except for 73 // those symbol addresses. For the i-th symbol address in the 74 //aggregate, its corresponding 4-byte or 8-byte elements in 'buffer' 75 // are filled with 0s. symbolPosInBuffer[i-1] records its position 76 // in 'buffer', and Symbols[i-1] records the Value*. 77 // 78 // Once we have this AggBuffer setup, we can choose how to print 79 // it out. 80 public: 81 unsigned numSymbols; // number of symbol addresses 82 83 private: 84 const unsigned size; // size of the buffer in bytes 85 std::vector<unsigned char> buffer; // the buffer 86 SmallVector<unsigned, 4> symbolPosInBuffer; 87 SmallVector<const Value *, 4> Symbols; 88 // SymbolsBeforeStripping[i] is the original form of Symbols[i] before 89 // stripping pointer casts, i.e., 90 // Symbols[i] == SymbolsBeforeStripping[i]->stripPointerCasts(). 91 // 92 // We need to keep these values because AggBuffer::print decides whether to 93 // emit a "generic()" cast for Symbols[i] depending on the address space of 94 // SymbolsBeforeStripping[i]. 95 SmallVector<const Value *, 4> SymbolsBeforeStripping; 96 unsigned curpos; 97 raw_ostream &O; 98 NVPTXAsmPrinter &AP; 99 bool EmitGeneric; 100 101 public: 102 AggBuffer(unsigned size, raw_ostream &O, NVPTXAsmPrinter &AP) 103 : size(size), buffer(size), O(O), AP(AP) { 104 curpos = 0; 105 numSymbols = 0; 106 EmitGeneric = AP.EmitGeneric; 107 } 108 109 unsigned addBytes(unsigned char *Ptr, int Num, int Bytes) { 110 assert((curpos + Num) <= size); 111 assert((curpos + Bytes) <= size); 112 for (int i = 0; i < Num; ++i) { 113 buffer[curpos] = Ptr[i]; 114 curpos++; 115 } 116 for (int i = Num; i < Bytes; ++i) { 117 buffer[curpos] = 0; 118 curpos++; 119 } 120 return curpos; 121 } 122 123 unsigned addZeros(int Num) { 124 assert((curpos + Num) <= size); 125 for (int i = 0; i < Num; ++i) { 126 buffer[curpos] = 0; 127 curpos++; 128 } 129 return curpos; 130 } 131 132 void addSymbol(const Value *GVar, const Value *GVarBeforeStripping) { 133 symbolPosInBuffer.push_back(curpos); 134 Symbols.push_back(GVar); 135 SymbolsBeforeStripping.push_back(GVarBeforeStripping); 136 numSymbols++; 137 } 138 139 void print() { 140 if (numSymbols == 0) { 141 // print out in bytes 142 for (unsigned i = 0; i < size; i++) { 143 if (i) 144 O << ", "; 145 O << (unsigned int) buffer[i]; 146 } 147 } else { 148 // print out in 4-bytes or 8-bytes 149 unsigned int pos = 0; 150 unsigned int nSym = 0; 151 unsigned int nextSymbolPos = symbolPosInBuffer[nSym]; 152 unsigned int nBytes = 4; 153 if (static_cast<const NVPTXTargetMachine &>(AP.TM).is64Bit()) 154 nBytes = 8; 155 for (pos = 0; pos < size; pos += nBytes) { 156 if (pos) 157 O << ", "; 158 if (pos == nextSymbolPos) { 159 const Value *v = Symbols[nSym]; 160 const Value *v0 = SymbolsBeforeStripping[nSym]; 161 if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) { 162 MCSymbol *Name = AP.getSymbol(GVar); 163 PointerType *PTy = dyn_cast<PointerType>(v0->getType()); 164 bool IsNonGenericPointer = false; // Is v0 a non-generic pointer? 165 if (PTy && PTy->getAddressSpace() != 0) { 166 IsNonGenericPointer = true; 167 } 168 if (EmitGeneric && !isa<Function>(v) && !IsNonGenericPointer) { 169 O << "generic("; 170 Name->print(O, AP.MAI); 171 O << ")"; 172 } else { 173 Name->print(O, AP.MAI); 174 } 175 } else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) { 176 const MCExpr *Expr = 177 AP.lowerConstantForGV(cast<Constant>(CExpr), false); 178 AP.printMCExpr(*Expr, O); 179 } else 180 llvm_unreachable("symbol type unknown"); 181 nSym++; 182 if (nSym >= numSymbols) 183 nextSymbolPos = size + 1; 184 else 185 nextSymbolPos = symbolPosInBuffer[nSym]; 186 } else if (nBytes == 4) 187 O << *(unsigned int *)(&buffer[pos]); 188 else 189 O << *(unsigned long long *)(&buffer[pos]); 190 } 191 } 192 } 193 }; 194 195 friend class AggBuffer; 196 197 private: 198 StringRef getPassName() const override { return "NVPTX Assembly Printer"; } 199 200 const Function *F; 201 std::string CurrentFnName; 202 203 void emitStartOfAsmFile(Module &M) override; 204 void emitBasicBlockStart(const MachineBasicBlock &MBB) override; 205 void emitFunctionEntryLabel() override; 206 void emitFunctionBodyStart() override; 207 void emitFunctionBodyEnd() override; 208 void emitImplicitDef(const MachineInstr *MI) const override; 209 210 void emitInstruction(const MachineInstr *) override; 211 void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI); 212 bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp); 213 MCOperand GetSymbolRef(const MCSymbol *Symbol); 214 unsigned encodeVirtualRegister(unsigned Reg); 215 216 void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &O, 217 const char *Modifier = nullptr); 218 void printModuleLevelGV(const GlobalVariable *GVar, raw_ostream &O, 219 bool = false); 220 void printParamName(Function::const_arg_iterator I, int paramIndex, 221 raw_ostream &O); 222 void emitGlobals(const Module &M); 223 void emitHeader(Module &M, raw_ostream &O, const NVPTXSubtarget &STI); 224 void emitKernelFunctionDirectives(const Function &F, raw_ostream &O) const; 225 void emitVirtualRegister(unsigned int vr, raw_ostream &); 226 void emitFunctionParamList(const Function *, raw_ostream &O); 227 void emitFunctionParamList(const MachineFunction &MF, raw_ostream &O); 228 void setAndEmitFunctionVirtualRegisters(const MachineFunction &MF); 229 void printReturnValStr(const Function *, raw_ostream &O); 230 void printReturnValStr(const MachineFunction &MF, raw_ostream &O); 231 bool PrintAsmOperand(const MachineInstr *MI, unsigned OpNo, 232 const char *ExtraCode, raw_ostream &) override; 233 void printOperand(const MachineInstr *MI, int opNum, raw_ostream &O); 234 bool PrintAsmMemoryOperand(const MachineInstr *MI, unsigned OpNo, 235 const char *ExtraCode, raw_ostream &) override; 236 237 const MCExpr *lowerConstantForGV(const Constant *CV, bool ProcessingGeneric); 238 void printMCExpr(const MCExpr &Expr, raw_ostream &OS); 239 240 protected: 241 bool doInitialization(Module &M) override; 242 bool doFinalization(Module &M) override; 243 244 private: 245 bool GlobalsEmitted; 246 247 // This is specific per MachineFunction. 248 const MachineRegisterInfo *MRI; 249 // The contents are specific for each 250 // MachineFunction. But the size of the 251 // array is not. 252 typedef DenseMap<unsigned, unsigned> VRegMap; 253 typedef DenseMap<const TargetRegisterClass *, VRegMap> VRegRCMap; 254 VRegRCMap VRegMapping; 255 256 // List of variables demoted to a function scope. 257 std::map<const Function *, std::vector<const GlobalVariable *>> localDecls; 258 259 void emitPTXGlobalVariable(const GlobalVariable *GVar, raw_ostream &O); 260 void emitPTXAddressSpace(unsigned int AddressSpace, raw_ostream &O) const; 261 std::string getPTXFundamentalTypeStr(Type *Ty, bool = true) const; 262 void printScalarConstant(const Constant *CPV, raw_ostream &O); 263 void printFPConstant(const ConstantFP *Fp, raw_ostream &O); 264 void bufferLEByte(const Constant *CPV, int Bytes, AggBuffer *aggBuffer); 265 void bufferAggregateConstant(const Constant *CV, AggBuffer *aggBuffer); 266 267 void emitLinkageDirective(const GlobalValue *V, raw_ostream &O); 268 void emitDeclarations(const Module &, raw_ostream &O); 269 void emitDeclaration(const Function *, raw_ostream &O); 270 void emitDemotedVars(const Function *, raw_ostream &); 271 272 bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo, 273 MCOperand &MCOp); 274 void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp); 275 276 bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const; 277 278 // Used to control the need to emit .generic() in the initializer of 279 // module scope variables. 280 // Although ptx supports the hybrid mode like the following, 281 // .global .u32 a; 282 // .global .u32 b; 283 // .global .u32 addr[] = {a, generic(b)} 284 // we have difficulty representing the difference in the NVVM IR. 285 // 286 // Since the address value should always be generic in CUDA C and always 287 // be specific in OpenCL, we use this simple control here. 288 // 289 bool EmitGeneric; 290 291 public: 292 NVPTXAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer) 293 : AsmPrinter(TM, std::move(Streamer)), 294 EmitGeneric(static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == 295 NVPTX::CUDA) {} 296 297 bool runOnMachineFunction(MachineFunction &F) override; 298 299 void getAnalysisUsage(AnalysisUsage &AU) const override { 300 AU.addRequired<MachineLoopInfo>(); 301 AsmPrinter::getAnalysisUsage(AU); 302 } 303 304 std::string getVirtualRegisterName(unsigned) const; 305 306 const MCSymbol *getFunctionFrameSymbol() const override; 307 }; 308 309 } // end namespace llvm 310 311 #endif // LLVM_LIB_TARGET_NVPTX_NVPTXASMPRINTER_H 312