//===- AMDGPUSplitModule.cpp ----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file Implements a module splitting algorithm designed to support the
/// FullLTO --lto-partitions option for parallel codegen. This is completely
/// different from the common SplitModule pass, as this system is designed with
/// AMDGPU in mind.
///
/// The basic idea of this module splitting implementation is the same as
/// SplitModule: load-balance the module's functions across a set of N
/// partitions to allow parallel codegen. However, it does it very
/// differently than the target-agnostic variant:
///   - The module has "split roots", which are kernels in the vast
///     majority of cases.
///   - Each root has a set of dependencies, and when a root and its
///     dependencies are considered "big", we try to put it in a partition where
///     most dependencies are already imported, to avoid duplicating large
///     amounts of code.
///   - There's special care for indirect calls in order to ensure
///     AMDGPUResourceUsageAnalysis can work correctly.
///
/// This file also includes a more elaborate logging system to enable
/// users to easily generate logs that (if desired) do not include any value
/// names, in order to not leak information about the source file.
/// Such logs are very helpful to understand and fix potential issues with
/// module splitting.
32 33 #include "AMDGPUSplitModule.h" 34 #include "AMDGPUTargetMachine.h" 35 #include "Utils/AMDGPUBaseInfo.h" 36 #include "llvm/ADT/DenseMap.h" 37 #include "llvm/ADT/SmallVector.h" 38 #include "llvm/ADT/StringExtras.h" 39 #include "llvm/ADT/StringRef.h" 40 #include "llvm/Analysis/CallGraph.h" 41 #include "llvm/Analysis/TargetTransformInfo.h" 42 #include "llvm/IR/Function.h" 43 #include "llvm/IR/Instruction.h" 44 #include "llvm/IR/Module.h" 45 #include "llvm/IR/User.h" 46 #include "llvm/IR/Value.h" 47 #include "llvm/Support/Casting.h" 48 #include "llvm/Support/Debug.h" 49 #include "llvm/Support/FileSystem.h" 50 #include "llvm/Support/Path.h" 51 #include "llvm/Support/Process.h" 52 #include "llvm/Support/SHA256.h" 53 #include "llvm/Support/Threading.h" 54 #include "llvm/Support/raw_ostream.h" 55 #include "llvm/Transforms/Utils/Cloning.h" 56 #include <algorithm> 57 #include <cassert> 58 #include <iterator> 59 #include <memory> 60 #include <utility> 61 #include <vector> 62 63 using namespace llvm; 64 65 #define DEBUG_TYPE "amdgpu-split-module" 66 67 namespace { 68 69 static cl::opt<float> LargeFnFactor( 70 "amdgpu-module-splitting-large-function-threshold", cl::init(2.0f), 71 cl::Hidden, 72 cl::desc( 73 "consider a function as large and needing special treatment when the " 74 "cost of importing it into a partition" 75 "exceeds the average cost of a partition by this factor; e;g. 
2.0 " 76 "means if the function and its dependencies is 2 times bigger than " 77 "an average partition; 0 disables large functions handling entirely")); 78 79 static cl::opt<float> LargeFnOverlapForMerge( 80 "amdgpu-module-splitting-large-function-merge-overlap", cl::init(0.8f), 81 cl::Hidden, 82 cl::desc( 83 "defines how much overlap between two large function's dependencies " 84 "is needed to put them in the same partition")); 85 86 static cl::opt<bool> NoExternalizeGlobals( 87 "amdgpu-module-splitting-no-externalize-globals", cl::Hidden, 88 cl::desc("disables externalization of global variable with local linkage; " 89 "may cause globals to be duplicated which increases binary size")); 90 91 static cl::opt<std::string> 92 LogDirOpt("amdgpu-module-splitting-log-dir", cl::Hidden, 93 cl::desc("output directory for AMDGPU module splitting logs")); 94 95 static cl::opt<bool> 96 LogPrivate("amdgpu-module-splitting-log-private", cl::Hidden, 97 cl::desc("hash value names before printing them in the AMDGPU " 98 "module splitting logs")); 99 100 using CostType = InstructionCost::CostType; 101 using PartitionID = unsigned; 102 using GetTTIFn = function_ref<const TargetTransformInfo &(Function &)>; 103 104 static bool isEntryPoint(const Function *F) { 105 return AMDGPU::isEntryFunctionCC(F->getCallingConv()); 106 } 107 108 static std::string getName(const Value &V) { 109 static bool HideNames; 110 111 static llvm::once_flag HideNameInitFlag; 112 llvm::call_once(HideNameInitFlag, [&]() { 113 if (LogPrivate.getNumOccurrences()) 114 HideNames = LogPrivate; 115 else { 116 const auto EV = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_PRIVATE"); 117 HideNames = (EV.value_or("0") != "0"); 118 } 119 }); 120 121 if (!HideNames) 122 return V.getName().str(); 123 return toHex(SHA256::hash(arrayRefFromStringRef(V.getName())), 124 /*LowerCase=*/true); 125 } 126 127 /// Main logging helper. 128 /// 129 /// Logging can be configured by the following environment variable. 
/// AMD_SPLIT_MODULE_LOG_DIR=<filepath>
///     If set, uses <filepath> as the directory to write logfiles to
///     each time module splitting is used.
/// AMD_SPLIT_MODULE_LOG_PRIVATE
///     If set to anything other than zero, all names are hidden.
///
/// Both environment variables have corresponding CL options which
/// takes priority over them.
///
/// Any output printed to the log files is also printed to dbgs() when -debug is
/// used and LLVM_DEBUG is defined.
///
/// This approach has a small disadvantage over LLVM_DEBUG though: logging logic
/// cannot be removed from the code (by building without debug). This probably
/// has a small performance cost because if some computation/formatting is
/// needed for logging purpose, it may be done everytime only to be ignored
/// by the logger.
///
/// As this pass only runs once and is not doing anything computationally
/// expensive, this is likely a reasonable trade-off.
///
/// If some computation should really be avoided when unused, users of the class
/// can check whether any logging will occur by using the bool operator.
///
/// \code
///   if (SML) {
///     // Executes only if logging to a file or if -debug is available and
///     // used.
///   }
/// \endcode
class SplitModuleLogger {
public:
  SplitModuleLogger(const Module &M) {
    // The CL option takes priority over the environment variable.
    std::string LogDir = LogDirOpt;
    if (LogDir.empty())
      LogDir = sys::Process::GetEnv("AMD_SPLIT_MODULE_LOG_DIR").value_or("");

    // No log dir specified means we don't need to log to a file.
    // We may still log to dbgs(), though.
    if (LogDir.empty())
      return;

    // If a log directory is specified, create a new file with a unique name in
    // that directory.
    int Fd;
    SmallString<0> PathTemplate;
    SmallString<0> RealPath;
    sys::path::append(PathTemplate, LogDir, "Module-%%-%%-%%-%%-%%-%%-%%.txt");
    if (auto Err =
            sys::fs::createUniqueFile(PathTemplate.str(), Fd, RealPath)) {
      report_fatal_error("Failed to create log file at '" + Twine(LogDir) +
                             "': " + Err.message(),
                         /*CrashDiag=*/false);
    }

    // Take ownership of the file descriptor; the stream closes it on
    // destruction.
    FileOS = std::make_unique<raw_fd_ostream>(Fd, /*shouldClose=*/true);
  }

  /// \returns true if a log file was successfully opened in the constructor.
  bool hasLogFile() const { return FileOS != nullptr; }

  /// \returns the log file stream. Only valid when \ref hasLogFile is true.
  raw_ostream &logfile() {
    assert(FileOS && "no logfile!");
    return *FileOS;
  }

  /// \returns true if this SML will log anything either to a file or dbgs().
  /// Can be used to avoid expensive computations that are ignored when logging
  /// is disabled.
  operator bool() const {
    return hasLogFile() || (DebugFlag && isCurrentDebugType(DEBUG_TYPE));
  }

private:
  /// Output stream for the log file; null when file logging is disabled.
  std::unique_ptr<raw_fd_ostream> FileOS;
};

/// Stream-insertion for the logger: forwards \p Val to dbgs() (under
/// LLVM_DEBUG) and to the log file, if one is open.
template <typename Ty>
static SplitModuleLogger &operator<<(SplitModuleLogger &SML, const Ty &Val) {
  static_assert(
      !std::is_same_v<Ty, Value>,
      "do not print values to logs directly, use handleName instead!");
  LLVM_DEBUG(dbgs() << Val);
  if (SML.hasLogFile())
    SML.logfile() << Val;
  return SML;
}

/// Calculate the cost of each function in \p M
/// \param SML Log Helper
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to analyze.
/// \param CostMap[out] Resulting Function -> Cost map.
/// \return The module's total cost.
static CostType
calculateFunctionCosts(SplitModuleLogger &SML, GetTTIFn GetTTI, Module &M,
                       DenseMap<const Function *, CostType> &CostMap) {
  CostType ModuleCost = 0;
  CostType KernelCost = 0;

  for (auto &Fn : M) {
    // Declarations have no body, hence no cost.
    if (Fn.isDeclaration())
      continue;

    // Sum the code-size cost of every instruction in the function.
    CostType FnCost = 0;
    const auto &TTI = GetTTI(Fn);
    for (const auto &BB : Fn) {
      for (const auto &I : BB) {
        auto Cost =
            TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
        assert(Cost != InstructionCost::getMax());
        // Assume expensive if we can't tell the cost of an instruction.
        CostType CostVal =
            Cost.getValue().value_or(TargetTransformInfo::TCC_Expensive);
        assert((FnCost + CostVal) >= FnCost && "Overflow!");
        FnCost += CostVal;
      }
    }

    assert(FnCost != 0);

    CostMap[&Fn] = FnCost;
    assert((ModuleCost + FnCost) >= ModuleCost && "Overflow!");
    ModuleCost += FnCost;

    // Track how much of the module's cost is in kernels vs. other functions.
    if (isEntryPoint(&Fn))
      KernelCost += FnCost;
  }

  CostType FnCost = (ModuleCost - KernelCost);
  // Avoid division by zero when printing percentages for an empty module.
  CostType ModuleCostOr1 = ModuleCost ? ModuleCost : 1;
  SML << "=> Total Module Cost: " << ModuleCost << '\n'
      << " => KernelCost: " << KernelCost << " ("
      << format("%0.2f", (float(KernelCost) / ModuleCostOr1) * 100) << "%)\n"
      << " => FnsCost: " << FnCost << " ("
      << format("%0.2f", (float(FnCost) / ModuleCostOr1) * 100) << "%)\n";

  return ModuleCost;
}

/// \returns true if \p F may be reached through an indirect call: it is a
/// definition, not a kernel, and either is externally visible or has its
/// address taken (ignoring assume-like calls, llvm.used and casted direct
/// calls).
static bool canBeIndirectlyCalled(const Function &F) {
  if (F.isDeclaration() || isEntryPoint(&F))
    return false;
  return !F.hasLocalLinkage() ||
         F.hasAddressTaken(/*PutOffender=*/nullptr,
                           /*IgnoreCallbackUses=*/false,
                           /*IgnoreAssumeLikeCalls=*/true,
                           /*IgnoreLLVMUsed=*/true,
                           /*IgnoreARCAttachedCall=*/false,
                           /*IgnoreCastedDirectCall=*/true);
}

/// When a function or any of its callees performs an indirect call, this
/// takes over \ref addAllDependencies and adds all potentially callable
/// functions to \p Fns so they can be counted as dependencies of the function.
///
/// This is needed due to how AMDGPUResourceUsageAnalysis operates: in the
/// presence of an indirect call, the function's resource usage is the same as
/// the most expensive function in the module.
/// \param M The module.
/// \param Fns[out] Resulting list of functions.
static void addAllIndirectCallDependencies(const Module &M,
                                           DenseSet<const Function *> &Fns) {
  for (const auto &Fn : M) {
    if (canBeIndirectlyCalled(Fn))
      Fns.insert(&Fn);
  }
}

/// Adds the functions that \p Fn may call to \p Fns, then recurses into each
/// callee until all reachable functions have been gathered.
///
/// \param SML Log Helper
/// \param CG Call graph for \p Fn's module.
/// \param Fn Current function to look at.
/// \param Fns[out] Resulting list of functions.
/// \param OnlyDirect Whether to only consider direct callees.
/// \param HadIndirectCall[out] Set to true if an indirect call was seen at some
/// point, either in \p Fn or in one of the function it calls. When that
/// happens, we fall back to adding all callable functions inside \p Fn's module
/// to \p Fns.
static void addAllDependencies(SplitModuleLogger &SML, const CallGraph &CG,
                               const Function &Fn,
                               DenseSet<const Function *> &Fns, bool OnlyDirect,
                               bool &HadIndirectCall) {
  assert(!Fn.isDeclaration());

  // Iterative DFS over the call graph, starting at Fn.
  const Module &M = *Fn.getParent();
  SmallVector<const Function *> WorkList({&Fn});
  while (!WorkList.empty()) {
    const auto &CurFn = *WorkList.pop_back_val();
    assert(!CurFn.isDeclaration());

    // Scan for an indirect call. If such a call is found, we have to
    // conservatively assume this can call all non-entrypoint functions in the
    // module.

    for (auto &CGEntry : *CG[&CurFn]) {
      auto *CGNode = CGEntry.second;
      auto *Callee = CGNode->getFunction();
      if (!Callee) {
        if (OnlyDirect)
          continue;

        // Functions have an edge towards CallsExternalNode if they're external
        // declarations, or if they do an indirect call. As we only process
        // definitions here, we know this means the function has an indirect
        // call. We then have to conservatively assume this can call all
        // non-entrypoint functions in the module.
        if (CGNode != CG.getCallsExternalNode())
          continue; // this is another function-less node we don't care about.

        SML << "Indirect call detected in " << getName(CurFn)
            << " - treating all non-entrypoint functions as "
               "potential dependencies\n";

        // TODO: Print an ORE as well ?
        addAllIndirectCallDependencies(M, Fns);
        HadIndirectCall = true;
        continue;
      }

      // Declarations are resolved at link time; they are not dependencies we
      // need to co-locate.
      if (Callee->isDeclaration())
        continue;

      // Only recurse into callees we haven't visited yet.
      auto [It, Inserted] = Fns.insert(Callee);
      if (Inserted)
        WorkList.push_back(Callee);
    }
  }
}

/// Contains information about a function and its dependencies.
/// This is a splitting root. The splitting algorithm works by
/// assigning these to partitions.
struct FunctionWithDependencies {
  FunctionWithDependencies(SplitModuleLogger &SML, CallGraph &CG,
                           const DenseMap<const Function *, CostType> &FnCosts,
                           const Function *Fn)
      : Fn(Fn) {
    // When Fn is not a kernel, we don't need to collect indirect callees.
    // Resource usage analysis is only performed on kernels, and we collect
    // indirect callees for resource usage analysis.
    addAllDependencies(SML, CG, *Fn, Dependencies,
                       /*OnlyDirect*/ !isEntryPoint(Fn), HasIndirectCall);
    TotalCost = FnCosts.at(Fn);
    for (const auto *Dep : Dependencies) {
      TotalCost += FnCosts.at(Dep);

      // We cannot duplicate functions with external linkage, or functions that
      // may be overriden at runtime.
      HasNonDuplicatableDependecy |=
          (Dep->hasExternalLinkage() || !Dep->isDefinitionExact());
    }
  }

  /// The root function itself.
  const Function *Fn = nullptr;
  /// All functions transitively reachable from \p Fn.
  DenseSet<const Function *> Dependencies;
  /// Whether \p Fn or any of its \ref Dependencies contains an indirect call.
  bool HasIndirectCall = false;
  /// Whether any of \p Fn's dependencies cannot be duplicated.
  bool HasNonDuplicatableDependecy = false;

  /// Cost of \p Fn plus all of its \ref Dependencies.
  CostType TotalCost = 0;

  /// \returns true if this function and its dependencies can be considered
  /// large according to \p Threshold.
  bool isLarge(CostType Threshold) const {
    return TotalCost > Threshold && !Dependencies.empty();
  }
};

/// Calculates how much overlap there is between \p A and \p B.
402 /// \return A number between 0.0 and 1.0, where 1.0 means A == B and 0.0 means A 403 /// and B have no shared elements. Kernels do not count in overlap calculation. 404 static float calculateOverlap(const DenseSet<const Function *> &A, 405 const DenseSet<const Function *> &B) { 406 DenseSet<const Function *> Total; 407 for (const auto *F : A) { 408 if (!isEntryPoint(F)) 409 Total.insert(F); 410 } 411 412 if (Total.empty()) 413 return 0.0f; 414 415 unsigned NumCommon = 0; 416 for (const auto *F : B) { 417 if (isEntryPoint(F)) 418 continue; 419 420 auto [It, Inserted] = Total.insert(F); 421 if (!Inserted) 422 ++NumCommon; 423 } 424 425 return static_cast<float>(NumCommon) / Total.size(); 426 } 427 428 /// Performs all of the partitioning work on \p M. 429 /// \param SML Log Helper 430 /// \param M Module to partition. 431 /// \param NumParts Number of partitions to create. 432 /// \param ModuleCost Total cost of all functions in \p M. 433 /// \param FnCosts Map of Function -> Cost 434 /// \param WorkList Functions and their dependencies to process in order. 435 /// \returns The created partitions (a vector of size \p NumParts ) 436 static std::vector<DenseSet<const Function *>> 437 doPartitioning(SplitModuleLogger &SML, Module &M, unsigned NumParts, 438 CostType ModuleCost, 439 const DenseMap<const Function *, CostType> &FnCosts, 440 const SmallVector<FunctionWithDependencies> &WorkList) { 441 442 SML << "\n--Partitioning Starts--\n"; 443 444 // Calculate a "large function threshold". When more than one function's total 445 // import cost exceeds this value, we will try to assign it to an existing 446 // partition to reduce the amount of duplication needed. 447 // 448 // e.g. let two functions X and Y have a import cost of ~10% of the module, we 449 // assign X to a partition as usual, but when we get to Y, we check if it's 450 // worth also putting it in Y's partition. 451 const CostType LargeFnThreshold = 452 LargeFnFactor ? 
CostType(((ModuleCost / NumParts) * LargeFnFactor)) 453 : std::numeric_limits<CostType>::max(); 454 455 std::vector<DenseSet<const Function *>> Partitions; 456 Partitions.resize(NumParts); 457 458 // Assign functions to partitions, and try to keep the partitions more or 459 // less balanced. We do that through a priority queue sorted in reverse, so we 460 // can always look at the partition with the least content. 461 // 462 // There are some cases where we will be deliberately unbalanced though. 463 // - Large functions: we try to merge with existing partitions to reduce code 464 // duplication. 465 // - Functions with indirect or external calls always go in the first 466 // partition (P0). 467 auto ComparePartitions = [](const std::pair<PartitionID, CostType> &a, 468 const std::pair<PartitionID, CostType> &b) { 469 // When two partitions have the same cost, assign to the one with the 470 // biggest ID first. This allows us to put things in P0 last, because P0 may 471 // have other stuff added later. 472 if (a.second == b.second) 473 return a.first < b.first; 474 return a.second > b.second; 475 }; 476 477 // We can't use priority_queue here because we need to be able to access any 478 // element. This makes this a bit inefficient as we need to sort it again 479 // everytime we change it, but it's a very small array anyway (likely under 64 480 // partitions) so it's a cheap operation. 481 std::vector<std::pair<PartitionID, CostType>> BalancingQueue; 482 for (unsigned I = 0; I < NumParts; ++I) 483 BalancingQueue.emplace_back(I, 0); 484 485 // Helper function to handle assigning a function to a partition. This takes 486 // care of updating the balancing queue. 
487 const auto AssignToPartition = [&](PartitionID PID, 488 const FunctionWithDependencies &FWD) { 489 auto &FnsInPart = Partitions[PID]; 490 FnsInPart.insert(FWD.Fn); 491 FnsInPart.insert(FWD.Dependencies.begin(), FWD.Dependencies.end()); 492 493 SML << "assign " << getName(*FWD.Fn) << " to P" << PID << "\n -> "; 494 if (!FWD.Dependencies.empty()) { 495 SML << FWD.Dependencies.size() << " dependencies added\n"; 496 }; 497 498 // Update the balancing queue. we scan backwards because in the common case 499 // the partition is at the end. 500 for (auto &[QueuePID, Cost] : reverse(BalancingQueue)) { 501 if (QueuePID == PID) { 502 CostType NewCost = 0; 503 for (auto *Fn : Partitions[PID]) 504 NewCost += FnCosts.at(Fn); 505 506 SML << "[Updating P" << PID << " Cost]:" << Cost << " -> " << NewCost; 507 if (Cost) { 508 SML << " (" << unsigned(((float(NewCost) / Cost) - 1) * 100) 509 << "% increase)"; 510 } 511 SML << '\n'; 512 513 Cost = NewCost; 514 } 515 } 516 517 sort(BalancingQueue, ComparePartitions); 518 }; 519 520 for (auto &CurFn : WorkList) { 521 // When a function has indirect calls, it must stay in the first partition 522 // alongside every reachable non-entry function. This is a nightmare case 523 // for splitting as it severely limits what we can do. 524 if (CurFn.HasIndirectCall) { 525 SML << "Function with indirect call(s): " << getName(*CurFn.Fn) 526 << " defaulting to P0\n"; 527 AssignToPartition(0, CurFn); 528 continue; 529 } 530 531 // When a function has non duplicatable dependencies, we have to keep it in 532 // the first partition as well. This is a conservative approach, a 533 // finer-grained approach could keep track of which dependencies are 534 // non-duplicatable exactly and just make sure they're grouped together. 
535 if (CurFn.HasNonDuplicatableDependecy) { 536 SML << "Function with externally visible dependency " 537 << getName(*CurFn.Fn) << " defaulting to P0\n"; 538 AssignToPartition(0, CurFn); 539 continue; 540 } 541 542 // Be smart with large functions to avoid duplicating their dependencies. 543 if (CurFn.isLarge(LargeFnThreshold)) { 544 assert(LargeFnOverlapForMerge >= 0.0f && LargeFnOverlapForMerge <= 1.0f); 545 SML << "Large Function: " << getName(*CurFn.Fn) 546 << " - looking for partition with at least " 547 << format("%0.2f", LargeFnOverlapForMerge * 100) << "% overlap\n"; 548 549 bool Assigned = false; 550 for (const auto &[PID, Fns] : enumerate(Partitions)) { 551 float Overlap = calculateOverlap(CurFn.Dependencies, Fns); 552 SML << " => " << format("%0.2f", Overlap * 100) << "% overlap with P" 553 << PID << '\n'; 554 if (Overlap > LargeFnOverlapForMerge) { 555 SML << " selecting P" << PID << '\n'; 556 AssignToPartition(PID, CurFn); 557 Assigned = true; 558 } 559 } 560 561 if (Assigned) 562 continue; 563 } 564 565 // Normal "load-balancing", assign to partition with least pressure. 566 auto [PID, CurCost] = BalancingQueue.back(); 567 AssignToPartition(PID, CurFn); 568 } 569 570 if (SML) { 571 for (const auto &[Idx, Part] : enumerate(Partitions)) { 572 CostType Cost = 0; 573 for (auto *Fn : Part) 574 Cost += FnCosts.at(Fn); 575 SML << "P" << Idx << " has a total cost of " << Cost << " (" 576 << format("%0.2f", (float(Cost) / ModuleCost) * 100) 577 << "% of source module)\n"; 578 } 579 580 SML << "--Partitioning Done--\n\n"; 581 } 582 583 // Check no functions were missed. 
#ifndef NDEBUG
  // Sanity check: every defined function must have ended up in a partition.
  DenseSet<const Function *> AllFunctions;
  for (const auto &Part : Partitions)
    AllFunctions.insert(Part.begin(), Part.end());

  for (auto &Fn : M) {
    if (!Fn.isDeclaration() && !AllFunctions.contains(&Fn)) {
      assert(AllFunctions.contains(&Fn) && "Missed a function?!");
    }
  }
#endif

  return Partitions;
}

/// Gives \p GV external linkage (with hidden visibility) if it is local, and
/// a name if it has none, so it stays resolvable when partitions are split
/// into separate modules.
static void externalize(GlobalValue &GV) {
  if (GV.hasLocalLinkage()) {
    GV.setLinkage(GlobalValue::ExternalLinkage);
    GV.setVisibility(GlobalValue::HiddenVisibility);
  }

  // Unnamed entities must be named consistently between modules. setName will
  // give a distinct name to each such entity.
  if (!GV.hasName())
    GV.setName("__llvmsplit_unnamed");
}

/// \returns true if \p Fn is used as the callee of at least one call site
/// (as opposed to only having its address taken).
static bool hasDirectCaller(const Function &Fn) {
  for (auto &U : Fn.uses()) {
    if (auto *CB = dyn_cast<CallBase>(U.getUser()); CB && CB->isCallee(&U))
      return true;
  }
  return false;
}

/// Top-level driver: splits \p M into \p N partitions and hands each cloned
/// part to \p ModuleCallback.
/// \param GetTTI Abstract getter for TargetTransformInfo.
/// \param M Module to split.
/// \param N Number of parts to create.
/// \param ModuleCallback Invoked once per created part, in order.
static void splitAMDGPUModule(
    GetTTIFn GetTTI, Module &M, unsigned N,
    function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {

  SplitModuleLogger SML(M);

  CallGraph CG(M);

  // Externalize functions whose address are taken.
  //
  // This is needed because partitioning is purely based on calls, but sometimes
  // a kernel/function may just look at the address of another local function
  // and not do anything (no calls). After partitioning, that local function may
  // end up in a different module (so it's just a declaration in the module
  // where its address is taken), which emits a "undefined hidden symbol" linker
  // error.
  //
  // Additionally, it guides partitioning to not duplicate this function if it's
  // called directly at some point.
  for (auto &Fn : M) {
    if (Fn.hasAddressTaken()) {
      if (Fn.hasLocalLinkage()) {
        SML << "[externalize] " << Fn.getName()
            << " because its address is taken\n";
      }
      externalize(Fn);
    }
  }

  // Externalize local GVs, which avoids duplicating their initializers, which
  // in turns helps keep code size in check.
  if (!NoExternalizeGlobals) {
    for (auto &GV : M.globals()) {
      if (GV.hasLocalLinkage())
        SML << "[externalize] GV " << GV.getName() << '\n';
      externalize(GV);
    }
  }

  // Start by calculating the cost of every function in the module, as well as
  // the module's overall cost.
  DenseMap<const Function *, CostType> FnCosts;
  const CostType ModuleCost = calculateFunctionCosts(SML, GetTTI, M, FnCosts);

  // First, gather every kernel into the worklist.
  SmallVector<FunctionWithDependencies> WorkList;
  for (auto &Fn : M) {
    if (isEntryPoint(&Fn) && !Fn.isDeclaration())
      WorkList.emplace_back(SML, CG, FnCosts, &Fn);
  }

  // Then, find missing functions that need to be considered as additional
  // roots. These can't be called in theory, but in practice we still have to
  // handle them to avoid linker errors.
  {
    DenseSet<const Function *> SeenFunctions;
    for (const auto &FWD : WorkList) {
      SeenFunctions.insert(FWD.Fn);
      SeenFunctions.insert(FWD.Dependencies.begin(), FWD.Dependencies.end());
    }

    for (auto &Fn : M) {
      // If this function is not part of any kernel's dependencies and isn't
      // directly called, consider it as a root.
      if (!Fn.isDeclaration() && !isEntryPoint(&Fn) &&
          !SeenFunctions.count(&Fn) && !hasDirectCaller(Fn)) {
        WorkList.emplace_back(SML, CG, FnCosts, &Fn);
      }
    }
  }

  // Sort the worklist so the most expensive roots are seen first.
  sort(WorkList, [&](auto &A, auto &B) {
    // Sort by total cost, and if the total cost is identical, sort
    // alphabetically.
    if (A.TotalCost == B.TotalCost)
      return A.Fn->getName() < B.Fn->getName();
    return A.TotalCost > B.TotalCost;
  });

  if (SML) {
    SML << "Worklist\n";
    for (const auto &FWD : WorkList) {
      SML << "[root] " << getName(*FWD.Fn) << " (totalCost:" << FWD.TotalCost
          << " indirect:" << FWD.HasIndirectCall
          << " hasNonDuplicatableDep:" << FWD.HasNonDuplicatableDependecy
          << ")\n";
      // Sort function names before printing to ensure determinism.
      SmallVector<std::string> SortedDepNames;
      SortedDepNames.reserve(FWD.Dependencies.size());
      for (const auto *Dep : FWD.Dependencies)
        SortedDepNames.push_back(getName(*Dep));
      sort(SortedDepNames);

      for (const auto &Name : SortedDepNames)
        SML << " [dependency] " << Name << '\n';
    }
  }

  // This performs all of the partitioning work.
  auto Partitions = doPartitioning(SML, M, N, ModuleCost, FnCosts, WorkList);
  assert(Partitions.size() == N);

  // If we didn't externalize GVs, then local GVs need to be conservatively
  // imported into every module (including their initializers), and then cleaned
  // up afterwards.
  const auto NeedsConservativeImport = [&](const GlobalValue *GV) {
    // We conservatively import private/internal GVs into every module and clean
    // them up afterwards.
    const auto *Var = dyn_cast<GlobalVariable>(GV);
    return Var && Var->hasLocalLinkage();
  };

  SML << "Creating " << N << " modules...\n";
  unsigned TotalFnImpls = 0;
  for (unsigned I = 0; I < N; ++I) {
    const auto &FnsInPart = Partitions[I];

    // Clone M, keeping only the globals that belong in this part.
    ValueToValueMapTy VMap;
    std::unique_ptr<Module> MPart(
        CloneModule(M, VMap, [&](const GlobalValue *GV) {
          // Functions go in their assigned partition.
          if (const auto *Fn = dyn_cast<Function>(GV))
            return FnsInPart.contains(Fn);

          if (NeedsConservativeImport(GV))
            return true;

          // Everything else goes in the first partition.
          return I == 0;
        }));

    // Clean-up conservatively imported GVs without any users.
    for (auto &GV : make_early_inc_range(MPart->globals())) {
      if (NeedsConservativeImport(&GV) && GV.use_empty())
        GV.eraseFromParent();
    }

    // Count definitions (and kernels among them) for logging/statistics.
    unsigned NumAllFns = 0, NumKernels = 0;
    for (auto &Cur : *MPart) {
      if (!Cur.isDeclaration()) {
        ++NumAllFns;
        if (isEntryPoint(&Cur))
          ++NumKernels;
      }
    }
    TotalFnImpls += NumAllFns;
    SML << " - Module " << I << " with " << NumAllFns << " functions ("
        << NumKernels << " kernels)\n";
    ModuleCallback(std::move(MPart));
  }

  // >100% indicates functions were duplicated across partitions.
  SML << TotalFnImpls << " function definitions across all modules ("
      << format("%0.2f", (float(TotalFnImpls) / FnCosts.size()) * 100)
      << "% of original module)\n";
}
} // namespace

PreservedAnalyses AMDGPUSplitModulePass::run(Module &M,
                                             ModuleAnalysisManager &MAM) {
  FunctionAnalysisManager &FAM =
      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  const auto TTIGetter = [&FAM](Function &F) -> const TargetTransformInfo & {
    return FAM.getResult<TargetIRAnalysis>(F);
  };
  splitAMDGPUModule(TTIGetter, M, N, ModuleCallback);
  // We don't change the original module.
  return PreservedAnalyses::all();
}