1 //===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// \file 10 /// This file contains the IR transform to lower external or indirect calls for 11 /// the ARM64EC calling convention. Such calls must go through the runtime, so 12 /// we can translate the calling convention for calls into the emulator. 13 /// 14 /// This subsumes Control Flow Guard handling. 15 /// 16 //===----------------------------------------------------------------------===// 17 18 #include "AArch64.h" 19 #include "llvm/ADT/SetVector.h" 20 #include "llvm/ADT/SmallString.h" 21 #include "llvm/ADT/SmallVector.h" 22 #include "llvm/ADT/Statistic.h" 23 #include "llvm/IR/CallingConv.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/IR/Instruction.h" 26 #include "llvm/InitializePasses.h" 27 #include "llvm/Object/COFF.h" 28 #include "llvm/Pass.h" 29 #include "llvm/Support/CommandLine.h" 30 #include "llvm/TargetParser/Triple.h" 31 32 using namespace llvm; 33 using namespace llvm::object; 34 35 using OperandBundleDef = OperandBundleDefT<Value *>; 36 37 #define DEBUG_TYPE "arm64eccalllowering" 38 39 STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered"); 40 41 static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect", 42 cl::Hidden, cl::init(true)); 43 static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden, 44 cl::init(true)); 45 46 namespace { 47 48 enum class ThunkType { GuestExit, Entry, Exit }; 49 50 class AArch64Arm64ECCallLowering : public ModulePass { 51 public: 52 static char ID; 53 AArch64Arm64ECCallLowering() : ModulePass(ID) { 54 initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry()); 55 } 56 57 Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs); 58 Function *buildEntryThunk(Function *F); 59 void lowerCall(CallBase *CB); 60 Function *buildGuestExitThunk(Function *F); 61 bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns); 62 bool runOnModule(Module &M) override; 63 64 private: 65 int cfguard_module_flag = 0; 66 FunctionType *GuardFnType = nullptr; 67 PointerType *GuardFnPtrType = nullptr; 68 Constant *GuardFnCFGlobal = nullptr; 69 Constant *GuardFnGlobal = nullptr; 70 Module *M = nullptr; 71 72 Type *PtrTy; 73 Type *I64Ty; 74 Type *VoidTy; 75 76 void getThunkType(FunctionType *FT, AttributeList AttrList, ThunkType TT, 77 raw_ostream &Out, FunctionType *&Arm64Ty, 78 FunctionType *&X64Ty); 79 void getThunkRetType(FunctionType *FT, AttributeList AttrList, 80 raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy, 81 SmallVectorImpl<Type *> &Arm64ArgTypes, 82 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr); 83 void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, ThunkType TT, 84 raw_ostream &Out, 85 SmallVectorImpl<Type *> &Arm64ArgTypes, 86 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr); 87 void canonicalizeThunkType(Type *T, Align Alignment, bool Ret, 88 uint64_t ArgSizeBytes, raw_ostream &Out, 89 Type *&Arm64Ty, Type *&X64Ty); 90 }; 91 92 } // end anonymous namespace 93 94 void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT, 95 AttributeList AttrList, 96 ThunkType TT, raw_ostream &Out, 97 FunctionType *&Arm64Ty, 98 FunctionType *&X64Ty) { 99 Out << (TT == ThunkType::Entry ? "$ientry_thunk$cdecl$" 100 : "$iexit_thunk$cdecl$"); 101 102 Type *Arm64RetTy; 103 Type *X64RetTy; 104 105 SmallVector<Type *> Arm64ArgTypes; 106 SmallVector<Type *> X64ArgTypes; 107 108 // The first argument to a thunk is the called function, stored in x9. 109 // For exit thunks, we pass the called function down to the emulator; 110 // for entry/guest exit thunks, we just call the Arm64 function directly. 111 if (TT == ThunkType::Exit) 112 Arm64ArgTypes.push_back(PtrTy); 113 X64ArgTypes.push_back(PtrTy); 114 115 bool HasSretPtr = false; 116 getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes, 117 X64ArgTypes, HasSretPtr); 118 119 getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes, 120 HasSretPtr); 121 122 Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false); 123 124 X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false); 125 } 126 127 void AArch64Arm64ECCallLowering::getThunkArgTypes( 128 FunctionType *FT, AttributeList AttrList, ThunkType TT, raw_ostream &Out, 129 SmallVectorImpl<Type *> &Arm64ArgTypes, 130 SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) { 131 132 Out << "$"; 133 if (FT->isVarArg()) { 134 // We treat the variadic function's thunk as a normal function 135 // with the following type on the ARM side: 136 // rettype exitthunk( 137 // ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5) 138 // 139 // that can coverage all types of variadic function. 140 // x9 is similar to normal exit thunk, store the called function. 141 // x0-x3 is the arguments be stored in registers. 142 // x4 is the address of the arguments on the stack. 143 // x5 is the size of the arguments on the stack. 144 // 145 // On the x64 side, it's the same except that x5 isn't set. 146 // 147 // If both the ARM and X64 sides are sret, there are only three 148 // arguments in registers. 149 // 150 // If the X64 side is sret, but the ARM side isn't, we pass an extra value 151 // to/from the X64 side, and let SelectionDAG transform it into a memory 152 // location. 153 Out << "varargs"; 154 155 // x0-x3 156 for (int i = HasSretPtr ? 1 : 0; i < 4; i++) { 157 Arm64ArgTypes.push_back(I64Ty); 158 X64ArgTypes.push_back(I64Ty); 159 } 160 161 // x4 162 Arm64ArgTypes.push_back(PtrTy); 163 X64ArgTypes.push_back(PtrTy); 164 // x5 165 Arm64ArgTypes.push_back(I64Ty); 166 if (TT != ThunkType::Entry) { 167 // FIXME: x5 isn't actually used by the x64 side; revisit once we 168 // have proper isel for varargs 169 X64ArgTypes.push_back(I64Ty); 170 } 171 return; 172 } 173 174 unsigned I = 0; 175 if (HasSretPtr) 176 I++; 177 178 if (I == FT->getNumParams()) { 179 Out << "v"; 180 return; 181 } 182 183 for (unsigned E = FT->getNumParams(); I != E; ++I) { 184 Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne(); 185 #if 0 186 // FIXME: Need more information about argument size; see 187 // https://reviews.llvm.org/D132926 188 uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I); 189 #else 190 uint64_t ArgSizeBytes = 0; 191 #endif 192 Type *Arm64Ty, *X64Ty; 193 canonicalizeThunkType(FT->getParamType(I), ParamAlign, 194 /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty); 195 Arm64ArgTypes.push_back(Arm64Ty); 196 X64ArgTypes.push_back(X64Ty); 197 } 198 } 199 200 void AArch64Arm64ECCallLowering::getThunkRetType( 201 FunctionType *FT, AttributeList AttrList, raw_ostream &Out, 202 Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes, 203 SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) { 204 Type *T = FT->getReturnType(); 205 #if 0 206 // FIXME: Need more information about argument size; see 207 // https://reviews.llvm.org/D132926 208 uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes(); 209 #else 210 int64_t ArgSizeBytes = 0; 211 #endif 212 if (T->isVoidTy()) { 213 if (FT->getNumParams()) { 214 auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet); 215 auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg); 216 if (SRetAttr.isValid() && InRegAttr.isValid()) { 217 // sret+inreg indicates a call that returns a C++ class value. This is 218 // actually equivalent to just passing and returning a void* pointer 219 // as the first argument. Translate it that way, instead of trying 220 // to model "inreg" in the thunk's calling convention, to simplify 221 // the rest of the code. 222 Out << "i8"; 223 Arm64RetTy = I64Ty; 224 X64RetTy = I64Ty; 225 return; 226 } 227 if (SRetAttr.isValid()) { 228 // FIXME: Sanity-check the sret type; if it's an integer or pointer, 229 // we'll get screwy mangling/codegen. 230 // FIXME: For large struct types, mangle as an integer argument and 231 // integer return, so we can reuse more thunks, instead of "m" syntax. 232 // (MSVC mangles this case as an integer return with no argument, but 233 // that's a miscompile.) 234 Type *SRetType = SRetAttr.getValueAsType(); 235 Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne(); 236 Type *Arm64Ty, *X64Ty; 237 canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes, 238 Out, Arm64Ty, X64Ty); 239 Arm64RetTy = VoidTy; 240 X64RetTy = VoidTy; 241 Arm64ArgTypes.push_back(FT->getParamType(0)); 242 X64ArgTypes.push_back(FT->getParamType(0)); 243 HasSretPtr = true; 244 return; 245 } 246 } 247 248 Out << "v"; 249 Arm64RetTy = VoidTy; 250 X64RetTy = VoidTy; 251 return; 252 } 253 254 canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy, 255 X64RetTy); 256 if (X64RetTy->isPointerTy()) { 257 // If the X64 type is canonicalized to a pointer, that means it's 258 // passed/returned indirectly. For a return value, that means it's an 259 // sret pointer. 260 X64ArgTypes.push_back(X64RetTy); 261 X64RetTy = VoidTy; 262 } 263 } 264 265 void AArch64Arm64ECCallLowering::canonicalizeThunkType( 266 Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out, 267 Type *&Arm64Ty, Type *&X64Ty) { 268 if (T->isFloatTy()) { 269 Out << "f"; 270 Arm64Ty = T; 271 X64Ty = T; 272 return; 273 } 274 275 if (T->isDoubleTy()) { 276 Out << "d"; 277 Arm64Ty = T; 278 X64Ty = T; 279 return; 280 } 281 282 if (T->isFloatingPointTy()) { 283 report_fatal_error( 284 "Only 32 and 64 bit floating points are supported for ARM64EC thunks"); 285 } 286 287 auto &DL = M->getDataLayout(); 288 289 if (auto *StructTy = dyn_cast<StructType>(T)) 290 if (StructTy->getNumElements() == 1) 291 T = StructTy->getElementType(0); 292 293 if (T->isArrayTy()) { 294 Type *ElementTy = T->getArrayElementType(); 295 uint64_t ElementCnt = T->getArrayNumElements(); 296 uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8; 297 uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes; 298 if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) { 299 Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes; 300 if (Alignment.value() >= 8 && !T->isPointerTy()) 301 Out << "a" << Alignment.value(); 302 Arm64Ty = T; 303 if (TotalSizeBytes <= 8) { 304 // Arm64 returns small structs of float/double in float registers; 305 // X64 uses RAX. 306 X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8); 307 } else { 308 // Struct is passed directly on Arm64, but indirectly on X64. 309 X64Ty = PtrTy; 310 } 311 return; 312 } else if (T->isFloatingPointTy()) { 313 report_fatal_error("Only 32 and 64 bit floating points are supported for " 314 "ARM64EC thunks"); 315 } 316 } 317 318 if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) { 319 Out << "i8"; 320 Arm64Ty = I64Ty; 321 X64Ty = I64Ty; 322 return; 323 } 324 325 unsigned TypeSize = ArgSizeBytes; 326 if (TypeSize == 0) 327 TypeSize = DL.getTypeSizeInBits(T) / 8; 328 Out << "m"; 329 if (TypeSize != 4) 330 Out << TypeSize; 331 if (Alignment.value() >= 8 && !T->isPointerTy()) 332 Out << "a" << Alignment.value(); 333 // FIXME: Try to canonicalize Arm64Ty more thoroughly? 334 Arm64Ty = T; 335 if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) { 336 // Pass directly in an integer register 337 X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8); 338 } else { 339 // Passed directly on Arm64, but indirectly on X64. 340 X64Ty = PtrTy; 341 } 342 } 343 344 // This function builds the "exit thunk", a function which translates 345 // arguments and return values when calling x64 code from AArch64 code. 346 Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT, 347 AttributeList Attrs) { 348 SmallString<256> ExitThunkName; 349 llvm::raw_svector_ostream ExitThunkStream(ExitThunkName); 350 FunctionType *Arm64Ty, *X64Ty; 351 getThunkType(FT, Attrs, ThunkType::Exit, ExitThunkStream, Arm64Ty, X64Ty); 352 if (Function *F = M->getFunction(ExitThunkName)) 353 return F; 354 355 Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0, 356 ExitThunkName, M); 357 F->setCallingConv(CallingConv::ARM64EC_Thunk_Native); 358 F->setSection(".wowthk$aa"); 359 F->setComdat(M->getOrInsertComdat(ExitThunkName)); 360 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 361 F->addFnAttr("frame-pointer", "all"); 362 // Only copy sret from the first argument. For C++ instance methods, clang can 363 // stick an sret marking on a later argument, but it doesn't actually affect 364 // the ABI, so we can omit it. This avoids triggering a verifier assertion. 365 if (FT->getNumParams()) { 366 auto SRet = Attrs.getParamAttr(0, Attribute::StructRet); 367 auto InReg = Attrs.getParamAttr(0, Attribute::InReg); 368 if (SRet.isValid() && !InReg.isValid()) 369 F->addParamAttr(1, SRet); 370 } 371 // FIXME: Copy anything other than sret? Shouldn't be necessary for normal 372 // C ABI, but might show up in other cases. 373 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F); 374 IRBuilder<> IRB(BB); 375 Value *CalleePtr = 376 M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy); 377 Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr); 378 auto &DL = M->getDataLayout(); 379 SmallVector<Value *> Args; 380 381 // Pass the called function in x9. 382 Args.push_back(F->arg_begin()); 383 384 Type *RetTy = Arm64Ty->getReturnType(); 385 if (RetTy != X64Ty->getReturnType()) { 386 // If the return type is an array or struct, translate it. Values of size 387 // 8 or less go into RAX; bigger values go into memory, and we pass a 388 // pointer. 389 if (DL.getTypeStoreSize(RetTy) > 8) { 390 Args.push_back(IRB.CreateAlloca(RetTy)); 391 } 392 } 393 394 for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) { 395 // Translate arguments from AArch64 calling convention to x86 calling 396 // convention. 397 // 398 // For simple types, we don't need to do any translation: they're 399 // represented the same way. (Implicit sign extension is not part of 400 // either convention.) 401 // 402 // The big thing we have to worry about is struct types... but 403 // fortunately AArch64 clang is pretty friendly here: the cases that need 404 // translation are always passed as a struct or array. (If we run into 405 // some cases where this doesn't work, we can teach clang to mark it up 406 // with an attribute.) 407 // 408 // The first argument is the called function, stored in x9. 409 if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() || 410 DL.getTypeStoreSize(Arg.getType()) > 8) { 411 Value *Mem = IRB.CreateAlloca(Arg.getType()); 412 IRB.CreateStore(&Arg, Mem); 413 if (DL.getTypeStoreSize(Arg.getType()) <= 8) { 414 Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType())); 415 Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy))); 416 } else 417 Args.push_back(Mem); 418 } else { 419 Args.push_back(&Arg); 420 } 421 } 422 // FIXME: Transfer necessary attributes? sret? anything else? 423 424 Callee = IRB.CreateBitCast(Callee, PtrTy); 425 CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args); 426 Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 427 428 Value *RetVal = Call; 429 if (RetTy != X64Ty->getReturnType()) { 430 // If we rewrote the return type earlier, convert the return value to 431 // the proper type. 432 if (DL.getTypeStoreSize(RetTy) > 8) { 433 RetVal = IRB.CreateLoad(RetTy, Args[1]); 434 } else { 435 Value *CastAlloca = IRB.CreateAlloca(RetTy); 436 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy)); 437 RetVal = IRB.CreateLoad(RetTy, CastAlloca); 438 } 439 } 440 441 if (RetTy->isVoidTy()) 442 IRB.CreateRetVoid(); 443 else 444 IRB.CreateRet(RetVal); 445 return F; 446 } 447 448 // This function builds the "entry thunk", a function which translates 449 // arguments and return values when calling AArch64 code from x64 code. 450 Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) { 451 SmallString<256> EntryThunkName; 452 llvm::raw_svector_ostream EntryThunkStream(EntryThunkName); 453 FunctionType *Arm64Ty, *X64Ty; 454 getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::Entry, 455 EntryThunkStream, Arm64Ty, X64Ty); 456 if (Function *F = M->getFunction(EntryThunkName)) 457 return F; 458 459 Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0, 460 EntryThunkName, M); 461 Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64); 462 Thunk->setSection(".wowthk$aa"); 463 Thunk->setComdat(M->getOrInsertComdat(EntryThunkName)); 464 // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.) 465 Thunk->addFnAttr("frame-pointer", "all"); 466 467 auto &DL = M->getDataLayout(); 468 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk); 469 IRBuilder<> IRB(BB); 470 471 Type *RetTy = Arm64Ty->getReturnType(); 472 Type *X64RetType = X64Ty->getReturnType(); 473 474 bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy(); 475 unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1; 476 unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size(); 477 478 // Translate arguments to call. 479 SmallVector<Value *> Args; 480 for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) { 481 Value *Arg = Thunk->getArg(i); 482 Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset); 483 if (ArgTy->isArrayTy() || ArgTy->isStructTy() || 484 DL.getTypeStoreSize(ArgTy) > 8) { 485 // Translate array/struct arguments to the expected type. 486 if (DL.getTypeStoreSize(ArgTy) <= 8) { 487 Value *CastAlloca = IRB.CreateAlloca(ArgTy); 488 IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy)); 489 Arg = IRB.CreateLoad(ArgTy, CastAlloca); 490 } else { 491 Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy)); 492 } 493 } 494 Args.push_back(Arg); 495 } 496 497 if (F->isVarArg()) { 498 // The 5th argument to variadic entry thunks is used to model the x64 sp 499 // which is passed to the thunk in x4, this can be passed to the callee as 500 // the variadic argument start address after skipping over the 32 byte 501 // shadow store. 502 503 // The EC thunk CC will assign any argument marked as InReg to x4. 504 Thunk->addParamAttr(5, Attribute::InReg); 505 Value *Arg = Thunk->getArg(5); 506 Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20)); 507 Args.push_back(Arg); 508 509 // Pass in a zero variadic argument size (in x5). 510 Args.push_back(IRB.getInt64(0)); 511 } 512 513 // Call the function passed to the thunk. 514 Value *Callee = Thunk->getArg(0); 515 Callee = IRB.CreateBitCast(Callee, PtrTy); 516 Value *Call = IRB.CreateCall(Arm64Ty, Callee, Args); 517 518 Value *RetVal = Call; 519 if (TransformDirectToSRet) { 520 IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy)); 521 } else if (X64RetType != RetTy) { 522 Value *CastAlloca = IRB.CreateAlloca(X64RetType); 523 IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy)); 524 RetVal = IRB.CreateLoad(X64RetType, CastAlloca); 525 } 526 527 // Return to the caller. Note that the isel has code to translate this 528 // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we 529 // could emit a tail call here, but that would require a dedicated calling 530 // convention, which seems more complicated overall.) 531 if (X64RetType->isVoidTy()) 532 IRB.CreateRetVoid(); 533 else 534 IRB.CreateRet(RetVal); 535 536 return Thunk; 537 } 538 539 // Builds the "guest exit thunk", a helper to call a function which may or may 540 // not be an exit thunk. (We optimistically assume non-dllimport function 541 // declarations refer to functions defined in AArch64 code; if the linker 542 // can't prove that, we use this routine instead.) 543 Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) { 544 llvm::raw_null_ostream NullThunkName; 545 FunctionType *Arm64Ty, *X64Ty; 546 getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::GuestExit, 547 NullThunkName, Arm64Ty, X64Ty); 548 auto MangledName = getArm64ECMangledFunctionName(F->getName().str()); 549 assert(MangledName && "Can't guest exit to function that's already native"); 550 std::string ThunkName = *MangledName; 551 if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) { 552 ThunkName.insert(ThunkName.find("@"), "$exit_thunk"); 553 } else { 554 ThunkName.append("$exit_thunk"); 555 } 556 Function *GuestExit = 557 Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M); 558 GuestExit->setComdat(M->getOrInsertComdat(ThunkName)); 559 GuestExit->setSection(".wowthk$aa"); 560 GuestExit->setMetadata( 561 "arm64ec_unmangled_name", 562 MDNode::get(M->getContext(), 563 MDString::get(M->getContext(), F->getName()))); 564 GuestExit->setMetadata( 565 "arm64ec_ecmangled_name", 566 MDNode::get(M->getContext(), 567 MDString::get(M->getContext(), *MangledName))); 568 F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {})); 569 BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit); 570 IRBuilder<> B(BB); 571 572 // Load the global symbol as a pointer to the check function. 573 Value *GuardFn; 574 if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf")) 575 GuardFn = GuardFnCFGlobal; 576 else 577 GuardFn = GuardFnGlobal; 578 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn); 579 580 // Create new call instruction. The CFGuard check should always be a call, 581 // even if the original CallBase is an Invoke or CallBr instruction. 582 Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes()); 583 CallInst *GuardCheck = B.CreateCall( 584 GuardFnType, GuardCheckLoad, 585 {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())}); 586 587 // Ensure that the first argument is passed in the correct register. 588 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 589 590 Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy); 591 SmallVector<Value *> Args; 592 for (Argument &Arg : GuestExit->args()) 593 Args.push_back(&Arg); 594 CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args); 595 Call->setTailCallKind(llvm::CallInst::TCK_MustTail); 596 597 if (Call->getType()->isVoidTy()) 598 B.CreateRetVoid(); 599 else 600 B.CreateRet(Call); 601 602 auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet); 603 auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg); 604 if (SRetAttr.isValid() && !InRegAttr.isValid()) { 605 GuestExit->addParamAttr(0, SRetAttr); 606 Call->addParamAttr(0, SRetAttr); 607 } 608 609 return GuestExit; 610 } 611 612 // Lower an indirect call with inline code. 613 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) { 614 assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() && 615 "Only applicable for Windows targets"); 616 617 IRBuilder<> B(CB); 618 Value *CalledOperand = CB->getCalledOperand(); 619 620 // If the indirect call is called within catchpad or cleanuppad, 621 // we need to copy "funclet" bundle of the call. 622 SmallVector<llvm::OperandBundleDef, 1> Bundles; 623 if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet)) 624 Bundles.push_back(OperandBundleDef(*Bundle)); 625 626 // Load the global symbol as a pointer to the check function. 627 Value *GuardFn; 628 if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf")) 629 GuardFn = GuardFnCFGlobal; 630 else 631 GuardFn = GuardFnGlobal; 632 LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn); 633 634 // Create new call instruction. The CFGuard check should always be a call, 635 // even if the original CallBase is an Invoke or CallBr instruction. 636 Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes()); 637 CallInst *GuardCheck = 638 B.CreateCall(GuardFnType, GuardCheckLoad, 639 {B.CreateBitCast(CalledOperand, B.getPtrTy()), 640 B.CreateBitCast(Thunk, B.getPtrTy())}, 641 Bundles); 642 643 // Ensure that the first argument is passed in the correct register. 644 GuardCheck->setCallingConv(CallingConv::CFGuard_Check); 645 646 Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType()); 647 CB->setCalledOperand(GuardRetVal); 648 } 649 650 bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) { 651 if (!GenerateThunks) 652 return false; 653 654 M = &Mod; 655 656 // Check if this module has the cfguard flag and read its value. 657 if (auto *MD = 658 mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard"))) 659 cfguard_module_flag = MD->getZExtValue(); 660 661 PtrTy = PointerType::getUnqual(M->getContext()); 662 I64Ty = Type::getInt64Ty(M->getContext()); 663 VoidTy = Type::getVoidTy(M->getContext()); 664 665 GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false); 666 GuardFnPtrType = PointerType::get(GuardFnType, 0); 667 GuardFnCFGlobal = 668 M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType); 669 GuardFnGlobal = 670 M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType); 671 672 SetVector<Function *> DirectCalledFns; 673 for (Function &F : Mod) 674 if (!F.isDeclaration() && 675 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 676 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) 677 processFunction(F, DirectCalledFns); 678 679 struct ThunkInfo { 680 Constant *Src; 681 Constant *Dst; 682 unsigned Kind; 683 }; 684 SmallVector<ThunkInfo> ThunkMapping; 685 for (Function &F : Mod) { 686 if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) && 687 F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native && 688 F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) { 689 if (!F.hasComdat()) 690 F.setComdat(Mod.getOrInsertComdat(F.getName())); 691 ThunkMapping.push_back({&F, buildEntryThunk(&F), 1}); 692 } 693 } 694 for (Function *F : DirectCalledFns) { 695 ThunkMapping.push_back( 696 {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4}); 697 if (!F->hasDLLImportStorageClass()) 698 ThunkMapping.push_back({buildGuestExitThunk(F), F, 0}); 699 } 700 701 if (!ThunkMapping.empty()) { 702 SmallVector<Constant *> ThunkMappingArrayElems; 703 for (ThunkInfo &Thunk : ThunkMapping) { 704 ThunkMappingArrayElems.push_back(ConstantStruct::getAnon( 705 {ConstantExpr::getBitCast(Thunk.Src, PtrTy), 706 ConstantExpr::getBitCast(Thunk.Dst, PtrTy), 707 ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))})); 708 } 709 Constant *ThunkMappingArray = ConstantArray::get( 710 llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(), 711 ThunkMappingArrayElems.size()), 712 ThunkMappingArrayElems); 713 new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false, 714 GlobalValue::ExternalLinkage, ThunkMappingArray, 715 "llvm.arm64ec.symbolmap"); 716 } 717 718 return true; 719 } 720 721 bool AArch64Arm64ECCallLowering::processFunction( 722 Function &F, SetVector<Function *> &DirectCalledFns) { 723 SmallVector<CallBase *, 8> IndirectCalls; 724 725 // For ARM64EC targets, a function definition's name is mangled differently 726 // from the normal symbol. We currently have no representation of this sort 727 // of symbol in IR, so we change the name to the mangled name, then store 728 // the unmangled name as metadata. Later passes that need the unmangled 729 // name (emitting the definition) can grab it from the metadata. 730 // 731 // FIXME: Handle functions with weak linkage? 732 if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) { 733 if (std::optional<std::string> MangledName = 734 getArm64ECMangledFunctionName(F.getName().str())) { 735 F.setMetadata("arm64ec_unmangled_name", 736 MDNode::get(M->getContext(), 737 MDString::get(M->getContext(), F.getName()))); 738 if (F.hasComdat() && F.getComdat()->getName() == F.getName()) { 739 Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value()); 740 SmallVector<GlobalObject *> ComdatUsers = 741 to_vector(F.getComdat()->getUsers()); 742 for (GlobalObject *User : ComdatUsers) 743 User->setComdat(MangledComdat); 744 } 745 F.setName(MangledName.value()); 746 } 747 } 748 749 // Iterate over the instructions to find all indirect call/invoke/callbr 750 // instructions. Make a separate list of pointers to indirect 751 // call/invoke/callbr instructions because the original instructions will be 752 // deleted as the checks are added. 753 for (BasicBlock &BB : F) { 754 for (Instruction &I : BB) { 755 auto *CB = dyn_cast<CallBase>(&I); 756 if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 || 757 CB->isInlineAsm()) 758 continue; 759 760 // We need to instrument any call that isn't directly calling an 761 // ARM64 function. 762 // 763 // FIXME: getCalledFunction() fails if there's a bitcast (e.g. 764 // unprototyped functions in C) 765 if (Function *F = CB->getCalledFunction()) { 766 if (!LowerDirectToIndirect || F->hasLocalLinkage() || 767 F->isIntrinsic() || !F->isDeclaration()) 768 continue; 769 770 DirectCalledFns.insert(F); 771 continue; 772 } 773 774 IndirectCalls.push_back(CB); 775 ++Arm64ECCallsLowered; 776 } 777 } 778 779 if (IndirectCalls.empty()) 780 return false; 781 782 for (CallBase *CB : IndirectCalls) 783 lowerCall(CB); 784 785 return true; 786 } 787 788 char AArch64Arm64ECCallLowering::ID = 0; 789 INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering", 790 "AArch64Arm64ECCallLowering", false, false) 791 792 ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() { 793 return new AArch64Arm64ECCallLowering; 794 } 795