1 //===-- NVPTXTargetTransformInfo.cpp - NVPTX specific TTI -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "NVPTXTargetTransformInfo.h" 10 #include "NVPTXUtilities.h" 11 #include "llvm/Analysis/LoopInfo.h" 12 #include "llvm/Analysis/TargetTransformInfo.h" 13 #include "llvm/Analysis/ValueTracking.h" 14 #include "llvm/CodeGen/BasicTTIImpl.h" 15 #include "llvm/CodeGen/CostTable.h" 16 #include "llvm/CodeGen/TargetLowering.h" 17 #include "llvm/IR/IntrinsicsNVPTX.h" 18 #include "llvm/Support/Debug.h" 19 #include <optional> 20 using namespace llvm; 21 22 #define DEBUG_TYPE "NVPTXtti" 23 24 // Whether the given intrinsic reads threadIdx.x/y/z. 25 static bool readsThreadIndex(const IntrinsicInst *II) { 26 switch (II->getIntrinsicID()) { 27 default: return false; 28 case Intrinsic::nvvm_read_ptx_sreg_tid_x: 29 case Intrinsic::nvvm_read_ptx_sreg_tid_y: 30 case Intrinsic::nvvm_read_ptx_sreg_tid_z: 31 return true; 32 } 33 } 34 35 static bool readsLaneId(const IntrinsicInst *II) { 36 return II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_laneid; 37 } 38 39 // Whether the given intrinsic is an atomic instruction in PTX. 40 static bool isNVVMAtomic(const IntrinsicInst *II) { 41 switch (II->getIntrinsicID()) { 42 default: return false; 43 case Intrinsic::nvvm_atomic_load_inc_32: 44 case Intrinsic::nvvm_atomic_load_dec_32: 45 46 case Intrinsic::nvvm_atomic_add_gen_f_cta: 47 case Intrinsic::nvvm_atomic_add_gen_f_sys: 48 case Intrinsic::nvvm_atomic_add_gen_i_cta: 49 case Intrinsic::nvvm_atomic_add_gen_i_sys: 50 case Intrinsic::nvvm_atomic_and_gen_i_cta: 51 case Intrinsic::nvvm_atomic_and_gen_i_sys: 52 case Intrinsic::nvvm_atomic_cas_gen_i_cta: 53 case Intrinsic::nvvm_atomic_cas_gen_i_sys: 54 case Intrinsic::nvvm_atomic_dec_gen_i_cta: 55 case Intrinsic::nvvm_atomic_dec_gen_i_sys: 56 case Intrinsic::nvvm_atomic_inc_gen_i_cta: 57 case Intrinsic::nvvm_atomic_inc_gen_i_sys: 58 case Intrinsic::nvvm_atomic_max_gen_i_cta: 59 case Intrinsic::nvvm_atomic_max_gen_i_sys: 60 case Intrinsic::nvvm_atomic_min_gen_i_cta: 61 case Intrinsic::nvvm_atomic_min_gen_i_sys: 62 case Intrinsic::nvvm_atomic_or_gen_i_cta: 63 case Intrinsic::nvvm_atomic_or_gen_i_sys: 64 case Intrinsic::nvvm_atomic_exch_gen_i_cta: 65 case Intrinsic::nvvm_atomic_exch_gen_i_sys: 66 case Intrinsic::nvvm_atomic_xor_gen_i_cta: 67 case Intrinsic::nvvm_atomic_xor_gen_i_sys: 68 return true; 69 } 70 } 71 72 bool NVPTXTTIImpl::isSourceOfDivergence(const Value *V) { 73 // Without inter-procedural analysis, we conservatively assume that arguments 74 // to __device__ functions are divergent. 75 if (const Argument *Arg = dyn_cast<Argument>(V)) 76 return !isKernelFunction(*Arg->getParent()); 77 78 if (const Instruction *I = dyn_cast<Instruction>(V)) { 79 // Without pointer analysis, we conservatively assume values loaded from 80 // generic or local address space are divergent. 81 if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { 82 unsigned AS = LI->getPointerAddressSpace(); 83 return AS == ADDRESS_SPACE_GENERIC || AS == ADDRESS_SPACE_LOCAL; 84 } 85 // Atomic instructions may cause divergence. Atomic instructions are 86 // executed sequentially across all threads in a warp. Therefore, an earlier 87 // executed thread may see different memory inputs than a later executed 88 // thread. For example, suppose *a = 0 initially. 89 // 90 // atom.global.add.s32 d, [a], 1 91 // 92 // returns 0 for the first thread that enters the critical region, and 1 for 93 // the second thread. 94 if (I->isAtomic()) 95 return true; 96 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { 97 // Instructions that read threadIdx are obviously divergent. 98 if (readsThreadIndex(II) || readsLaneId(II)) 99 return true; 100 // Handle the NVPTX atomic intrinsics that cannot be represented as an 101 // atomic IR instruction. 102 if (isNVVMAtomic(II)) 103 return true; 104 } 105 // Conservatively consider the return value of function calls as divergent. 106 // We could analyze callees with bodies more precisely using 107 // inter-procedural analysis. 108 if (isa<CallInst>(I)) 109 return true; 110 } 111 112 return false; 113 } 114 115 // Convert NVVM intrinsics to target-generic LLVM code where possible. 116 static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) { 117 // Each NVVM intrinsic we can simplify can be replaced with one of: 118 // 119 // * an LLVM intrinsic, 120 // * an LLVM cast operation, 121 // * an LLVM binary operation, or 122 // * ad-hoc LLVM IR for the particular operation. 123 124 // Some transformations are only valid when the module's 125 // flush-denormals-to-zero (ftz) setting is true/false, whereas other 126 // transformations are valid regardless of the module's ftz setting. 127 enum FtzRequirementTy { 128 FTZ_Any, // Any ftz setting is ok. 129 FTZ_MustBeOn, // Transformation is valid only if ftz is on. 130 FTZ_MustBeOff, // Transformation is valid only if ftz is off. 131 }; 132 // Classes of NVVM intrinsics that can't be replaced one-to-one with a 133 // target-generic intrinsic, cast op, or binary op but that we can nonetheless 134 // simplify. 135 enum SpecialCase { 136 SPC_Reciprocal, 137 }; 138 139 // SimplifyAction is a poor-man's variant (plus an additional flag) that 140 // represents how to replace an NVVM intrinsic with target-generic LLVM IR. 141 struct SimplifyAction { 142 // Invariant: At most one of these Optionals has a value. 143 std::optional<Intrinsic::ID> IID; 144 std::optional<Instruction::CastOps> CastOp; 145 std::optional<Instruction::BinaryOps> BinaryOp; 146 std::optional<SpecialCase> Special; 147 148 FtzRequirementTy FtzRequirement = FTZ_Any; 149 // Denormal handling is guarded by different attributes depending on the 150 // type (denormal-fp-math vs denormal-fp-math-f32), take note of halfs. 151 bool IsHalfTy = false; 152 153 SimplifyAction() = default; 154 155 SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq, 156 bool IsHalfTy = false) 157 : IID(IID), FtzRequirement(FtzReq), IsHalfTy(IsHalfTy) {} 158 159 // Cast operations don't have anything to do with FTZ, so we skip that 160 // argument. 161 SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {} 162 163 SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq) 164 : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {} 165 166 SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq) 167 : Special(Special), FtzRequirement(FtzReq) {} 168 }; 169 170 // Try to generate a SimplifyAction describing how to replace our 171 // IntrinsicInstr with target-generic LLVM IR. 172 const SimplifyAction Action = [II]() -> SimplifyAction { 173 switch (II->getIntrinsicID()) { 174 // NVVM intrinsics that map directly to LLVM intrinsics. 175 case Intrinsic::nvvm_ceil_d: 176 return {Intrinsic::ceil, FTZ_Any}; 177 case Intrinsic::nvvm_ceil_f: 178 return {Intrinsic::ceil, FTZ_MustBeOff}; 179 case Intrinsic::nvvm_ceil_ftz_f: 180 return {Intrinsic::ceil, FTZ_MustBeOn}; 181 case Intrinsic::nvvm_fabs_d: 182 return {Intrinsic::fabs, FTZ_Any}; 183 case Intrinsic::nvvm_fabs_f: 184 return {Intrinsic::fabs, FTZ_MustBeOff}; 185 case Intrinsic::nvvm_fabs_ftz_f: 186 return {Intrinsic::fabs, FTZ_MustBeOn}; 187 case Intrinsic::nvvm_floor_d: 188 return {Intrinsic::floor, FTZ_Any}; 189 case Intrinsic::nvvm_floor_f: 190 return {Intrinsic::floor, FTZ_MustBeOff}; 191 case Intrinsic::nvvm_floor_ftz_f: 192 return {Intrinsic::floor, FTZ_MustBeOn}; 193 case Intrinsic::nvvm_fma_rn_d: 194 return {Intrinsic::fma, FTZ_Any}; 195 case Intrinsic::nvvm_fma_rn_f: 196 return {Intrinsic::fma, FTZ_MustBeOff}; 197 case Intrinsic::nvvm_fma_rn_ftz_f: 198 return {Intrinsic::fma, FTZ_MustBeOn}; 199 case Intrinsic::nvvm_fma_rn_f16: 200 return {Intrinsic::fma, FTZ_MustBeOff, true}; 201 case Intrinsic::nvvm_fma_rn_ftz_f16: 202 return {Intrinsic::fma, FTZ_MustBeOn, true}; 203 case Intrinsic::nvvm_fma_rn_f16x2: 204 return {Intrinsic::fma, FTZ_MustBeOff, true}; 205 case Intrinsic::nvvm_fma_rn_ftz_f16x2: 206 return {Intrinsic::fma, FTZ_MustBeOn, true}; 207 case Intrinsic::nvvm_fmax_d: 208 return {Intrinsic::maxnum, FTZ_Any}; 209 case Intrinsic::nvvm_fmax_f: 210 return {Intrinsic::maxnum, FTZ_MustBeOff}; 211 case Intrinsic::nvvm_fmax_ftz_f: 212 return {Intrinsic::maxnum, FTZ_MustBeOn}; 213 case Intrinsic::nvvm_fmax_nan_f: 214 return {Intrinsic::maximum, FTZ_MustBeOff}; 215 case Intrinsic::nvvm_fmax_ftz_nan_f: 216 return {Intrinsic::maximum, FTZ_MustBeOn}; 217 case Intrinsic::nvvm_fmax_f16: 218 return {Intrinsic::maxnum, FTZ_MustBeOff, true}; 219 case Intrinsic::nvvm_fmax_ftz_f16: 220 return {Intrinsic::maxnum, FTZ_MustBeOn, true}; 221 case Intrinsic::nvvm_fmax_f16x2: 222 return {Intrinsic::maxnum, FTZ_MustBeOff, true}; 223 case Intrinsic::nvvm_fmax_ftz_f16x2: 224 return {Intrinsic::maxnum, FTZ_MustBeOn, true}; 225 case Intrinsic::nvvm_fmax_nan_f16: 226 return {Intrinsic::maximum, FTZ_MustBeOff, true}; 227 case Intrinsic::nvvm_fmax_ftz_nan_f16: 228 return {Intrinsic::maximum, FTZ_MustBeOn, true}; 229 case Intrinsic::nvvm_fmax_nan_f16x2: 230 return {Intrinsic::maximum, FTZ_MustBeOff, true}; 231 case Intrinsic::nvvm_fmax_ftz_nan_f16x2: 232 return {Intrinsic::maximum, FTZ_MustBeOn, true}; 233 case Intrinsic::nvvm_fmin_d: 234 return {Intrinsic::minnum, FTZ_Any}; 235 case Intrinsic::nvvm_fmin_f: 236 return {Intrinsic::minnum, FTZ_MustBeOff}; 237 case Intrinsic::nvvm_fmin_ftz_f: 238 return {Intrinsic::minnum, FTZ_MustBeOn}; 239 case Intrinsic::nvvm_fmin_nan_f: 240 return {Intrinsic::minimum, FTZ_MustBeOff}; 241 case Intrinsic::nvvm_fmin_ftz_nan_f: 242 return {Intrinsic::minimum, FTZ_MustBeOn}; 243 case Intrinsic::nvvm_fmin_f16: 244 return {Intrinsic::minnum, FTZ_MustBeOff, true}; 245 case Intrinsic::nvvm_fmin_ftz_f16: 246 return {Intrinsic::minnum, FTZ_MustBeOn, true}; 247 case Intrinsic::nvvm_fmin_f16x2: 248 return {Intrinsic::minnum, FTZ_MustBeOff, true}; 249 case Intrinsic::nvvm_fmin_ftz_f16x2: 250 return {Intrinsic::minnum, FTZ_MustBeOn, true}; 251 case Intrinsic::nvvm_fmin_nan_f16: 252 return {Intrinsic::minimum, FTZ_MustBeOff, true}; 253 case Intrinsic::nvvm_fmin_ftz_nan_f16: 254 return {Intrinsic::minimum, FTZ_MustBeOn, true}; 255 case Intrinsic::nvvm_fmin_nan_f16x2: 256 return {Intrinsic::minimum, FTZ_MustBeOff, true}; 257 case Intrinsic::nvvm_fmin_ftz_nan_f16x2: 258 return {Intrinsic::minimum, FTZ_MustBeOn, true}; 259 case Intrinsic::nvvm_round_d: 260 return {Intrinsic::round, FTZ_Any}; 261 case Intrinsic::nvvm_round_f: 262 return {Intrinsic::round, FTZ_MustBeOff}; 263 case Intrinsic::nvvm_round_ftz_f: 264 return {Intrinsic::round, FTZ_MustBeOn}; 265 case Intrinsic::nvvm_sqrt_rn_d: 266 return {Intrinsic::sqrt, FTZ_Any}; 267 case Intrinsic::nvvm_sqrt_f: 268 // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the 269 // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts 270 // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are 271 // the versions with explicit ftz-ness. 272 return {Intrinsic::sqrt, FTZ_Any}; 273 case Intrinsic::nvvm_sqrt_rn_f: 274 return {Intrinsic::sqrt, FTZ_MustBeOff}; 275 case Intrinsic::nvvm_sqrt_rn_ftz_f: 276 return {Intrinsic::sqrt, FTZ_MustBeOn}; 277 case Intrinsic::nvvm_trunc_d: 278 return {Intrinsic::trunc, FTZ_Any}; 279 case Intrinsic::nvvm_trunc_f: 280 return {Intrinsic::trunc, FTZ_MustBeOff}; 281 case Intrinsic::nvvm_trunc_ftz_f: 282 return {Intrinsic::trunc, FTZ_MustBeOn}; 283 284 // NVVM intrinsics that map to LLVM cast operations. 285 // 286 // Note that llvm's target-generic conversion operators correspond to the rz 287 // (round to zero) versions of the nvvm conversion intrinsics, even though 288 // most everything else here uses the rn (round to nearest even) nvvm ops. 289 case Intrinsic::nvvm_d2i_rz: 290 case Intrinsic::nvvm_f2i_rz: 291 case Intrinsic::nvvm_d2ll_rz: 292 case Intrinsic::nvvm_f2ll_rz: 293 return {Instruction::FPToSI}; 294 case Intrinsic::nvvm_d2ui_rz: 295 case Intrinsic::nvvm_f2ui_rz: 296 case Intrinsic::nvvm_d2ull_rz: 297 case Intrinsic::nvvm_f2ull_rz: 298 return {Instruction::FPToUI}; 299 case Intrinsic::nvvm_i2d_rz: 300 case Intrinsic::nvvm_i2f_rz: 301 case Intrinsic::nvvm_ll2d_rz: 302 case Intrinsic::nvvm_ll2f_rz: 303 return {Instruction::SIToFP}; 304 case Intrinsic::nvvm_ui2d_rz: 305 case Intrinsic::nvvm_ui2f_rz: 306 case Intrinsic::nvvm_ull2d_rz: 307 case Intrinsic::nvvm_ull2f_rz: 308 return {Instruction::UIToFP}; 309 310 // NVVM intrinsics that map to LLVM binary ops. 311 case Intrinsic::nvvm_add_rn_d: 312 return {Instruction::FAdd, FTZ_Any}; 313 case Intrinsic::nvvm_add_rn_f: 314 return {Instruction::FAdd, FTZ_MustBeOff}; 315 case Intrinsic::nvvm_add_rn_ftz_f: 316 return {Instruction::FAdd, FTZ_MustBeOn}; 317 case Intrinsic::nvvm_mul_rn_d: 318 return {Instruction::FMul, FTZ_Any}; 319 case Intrinsic::nvvm_mul_rn_f: 320 return {Instruction::FMul, FTZ_MustBeOff}; 321 case Intrinsic::nvvm_mul_rn_ftz_f: 322 return {Instruction::FMul, FTZ_MustBeOn}; 323 case Intrinsic::nvvm_div_rn_d: 324 return {Instruction::FDiv, FTZ_Any}; 325 case Intrinsic::nvvm_div_rn_f: 326 return {Instruction::FDiv, FTZ_MustBeOff}; 327 case Intrinsic::nvvm_div_rn_ftz_f: 328 return {Instruction::FDiv, FTZ_MustBeOn}; 329 330 // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but 331 // need special handling. 332 // 333 // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just 334 // as well. 335 case Intrinsic::nvvm_rcp_rn_d: 336 return {SPC_Reciprocal, FTZ_Any}; 337 case Intrinsic::nvvm_rcp_rn_f: 338 return {SPC_Reciprocal, FTZ_MustBeOff}; 339 case Intrinsic::nvvm_rcp_rn_ftz_f: 340 return {SPC_Reciprocal, FTZ_MustBeOn}; 341 342 // We do not currently simplify intrinsics that give an approximate 343 // answer. These include: 344 // 345 // - nvvm_cos_approx_{f,ftz_f} 346 // - nvvm_ex2_approx_{d,f,ftz_f} 347 // - nvvm_lg2_approx_{d,f,ftz_f} 348 // - nvvm_sin_approx_{f,ftz_f} 349 // - nvvm_sqrt_approx_{f,ftz_f} 350 // - nvvm_rsqrt_approx_{d,f,ftz_f} 351 // - nvvm_div_approx_{ftz_d,ftz_f,f} 352 // - nvvm_rcp_approx_ftz_d 353 // 354 // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast" 355 // means that fastmath is enabled in the intrinsic. Unfortunately only 356 // binary operators (currently) have a fastmath bit in SelectionDAG, so 357 // this information gets lost and we can't select on it. 358 // 359 // TODO: div and rcp are lowered to a binary op, so these we could in 360 // theory lower them to "fast fdiv". 361 362 default: 363 return {}; 364 } 365 }(); 366 367 // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we 368 // can bail out now. (Notice that in the case that IID is not an NVVM 369 // intrinsic, we don't have to look up any module metadata, as 370 // FtzRequirementTy will be FTZ_Any.) 371 if (Action.FtzRequirement != FTZ_Any) { 372 // FIXME: Broken for f64 373 DenormalMode Mode = II->getFunction()->getDenormalMode( 374 Action.IsHalfTy ? APFloat::IEEEhalf() : APFloat::IEEEsingle()); 375 bool FtzEnabled = Mode.Output == DenormalMode::PreserveSign; 376 377 if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn)) 378 return nullptr; 379 } 380 381 // Simplify to target-generic intrinsic. 382 if (Action.IID) { 383 SmallVector<Value *, 4> Args(II->args()); 384 // All the target-generic intrinsics currently of interest to us have one 385 // type argument, equal to that of the nvvm intrinsic's argument. 386 Type *Tys[] = {II->getArgOperand(0)->getType()}; 387 return CallInst::Create( 388 Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args); 389 } 390 391 // Simplify to target-generic binary op. 392 if (Action.BinaryOp) 393 return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0), 394 II->getArgOperand(1), II->getName()); 395 396 // Simplify to target-generic cast op. 397 if (Action.CastOp) 398 return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(), 399 II->getName()); 400 401 // All that's left are the special cases. 402 if (!Action.Special) 403 return nullptr; 404 405 switch (*Action.Special) { 406 case SPC_Reciprocal: 407 // Simplify reciprocal. 408 return BinaryOperator::Create( 409 Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1), 410 II->getArgOperand(0), II->getName()); 411 } 412 llvm_unreachable("All SpecialCase enumerators should be handled in switch."); 413 } 414 415 std::optional<Instruction *> 416 NVPTXTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const { 417 if (Instruction *I = simplifyNvvmIntrinsic(&II, IC)) { 418 return I; 419 } 420 return std::nullopt; 421 } 422 423 InstructionCost NVPTXTTIImpl::getArithmeticInstrCost( 424 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, 425 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info, 426 ArrayRef<const Value *> Args, 427 const Instruction *CxtI) { 428 // Legalize the type. 429 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty); 430 431 int ISD = TLI->InstructionOpcodeToISD(Opcode); 432 433 switch (ISD) { 434 default: 435 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, 436 Op2Info); 437 case ISD::ADD: 438 case ISD::MUL: 439 case ISD::XOR: 440 case ISD::OR: 441 case ISD::AND: 442 // The machine code (SASS) simulates an i64 with two i32. Therefore, we 443 // estimate that arithmetic operations on i64 are twice as expensive as 444 // those on types that can fit into one machine register. 445 if (LT.second.SimpleTy == MVT::i64) 446 return 2 * LT.first; 447 // Delegate other cases to the basic TTI. 448 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, 449 Op2Info); 450 } 451 } 452 453 void NVPTXTTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE, 454 TTI::UnrollingPreferences &UP, 455 OptimizationRemarkEmitter *ORE) { 456 BaseT::getUnrollingPreferences(L, SE, UP, ORE); 457 458 // Enable partial unrolling and runtime unrolling, but reduce the 459 // threshold. This partially unrolls small loops which are often 460 // unrolled by the PTX to SASS compiler and unrolling earlier can be 461 // beneficial. 462 UP.Partial = UP.Runtime = true; 463 UP.PartialThreshold = UP.Threshold / 4; 464 } 465 466 void NVPTXTTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE, 467 TTI::PeelingPreferences &PP) { 468 BaseT::getPeelingPreferences(L, SE, PP); 469 } 470