//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file contains the IR transform to lower external or indirect calls for
/// the ARM64EC calling convention. Such calls must go through the runtime, so
/// we can translate the calling convention for calls into the emulator.
///
/// This subsumes Control Flow Guard handling.
///
//===----------------------------------------------------------------------===//

#include "AArch64.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/InitializePasses.h"
#include "llvm/Object/COFF.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/TargetParser/Triple.h"

using namespace llvm;
using namespace llvm::object;

using OperandBundleDef = OperandBundleDefT<Value *>;

#define DEBUG_TYPE "arm64eccalllowering"

STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");

static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
                                           cl::Hidden, cl::init(true));
static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
                                    cl::init(true));

namespace {

// The three flavors of thunk this pass emits:
// - Entry: translates x64 -> AArch64 calls (x64 caller, native callee).
// - Exit: translates AArch64 -> x64 calls (native caller, emulated callee).
// - GuestExit: wraps a possibly-native function so the OS can redirect to an
//   exit thunk if the target turns out to be x64 code.
enum class ThunkType { GuestExit, Entry, Exit };

// Module pass that rewrites calls and synthesizes the entry/exit/guest-exit
// thunks plus the "llvm.arm64ec.symbolmap" table consumed by the linker.
class AArch64Arm64ECCallLowering : public ModulePass {
public:
  static char ID;
  AArch64Arm64ECCallLowering() : ModulePass(ID) {
    initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
  }

  Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
  Function *buildEntryThunk(Function *F);
  void lowerCall(CallBase *CB);
  Function *buildGuestExitThunk(Function *F);
  bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
  bool runOnModule(Module &M) override;

private:
  // Value of the module's "cfguard" flag (2 means CFGuard checks enabled).
  int cfguard_module_flag = 0;
  FunctionType *GuardFnType = nullptr;
  PointerType *GuardFnPtrType = nullptr;
  // __os_arm64x_check_icall_cfg / __os_arm64x_check_icall dispatch globals.
  Constant *GuardFnCFGlobal = nullptr;
  Constant *GuardFnGlobal = nullptr;
  Module *M = nullptr;

  // Cached types, initialized in runOnModule.
  Type *PtrTy;
  Type *I64Ty;
  Type *VoidTy;

  void getThunkType(FunctionType *FT, AttributeList AttrList, ThunkType TT,
                    raw_ostream &Out, FunctionType *&Arm64Ty,
                    FunctionType *&X64Ty);
  void getThunkRetType(FunctionType *FT, AttributeList AttrList,
                       raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
                       SmallVectorImpl<Type *> &Arm64ArgTypes,
                       SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr);
  void getThunkArgTypes(FunctionType *FT, AttributeList AttrList, ThunkType TT,
                        raw_ostream &Out,
                        SmallVectorImpl<Type *> &Arm64ArgTypes,
                        SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr);
  void canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
                             uint64_t ArgSizeBytes, raw_ostream &Out,
                             Type *&Arm64Ty, Type *&X64Ty);
};

} // end anonymous namespace

// Compute the mangled thunk name for a function of type FT with attributes
// AttrList, writing it to Out, and compute the AArch64-side and x64-side
// function types the thunk of kind TT will use.
void AArch64Arm64ECCallLowering::getThunkType(FunctionType *FT,
                                              AttributeList AttrList,
                                              ThunkType TT, raw_ostream &Out,
                                              FunctionType *&Arm64Ty,
                                              FunctionType *&X64Ty) {
  // GuestExit thunks share the exit-thunk mangling (callers pass a null
  // stream for them, so the name is discarded anyway).
  Out << (TT == ThunkType::Entry ? "$ientry_thunk$cdecl$"
                                 : "$iexit_thunk$cdecl$");

  Type *Arm64RetTy;
  Type *X64RetTy;

  SmallVector<Type *> Arm64ArgTypes;
  SmallVector<Type *> X64ArgTypes;

  // The first argument to a thunk is the called function, stored in x9.
  // For exit thunks, we pass the called function down to the emulator;
  // for entry/guest exit thunks, we just call the Arm64 function directly.
  if (TT == ThunkType::Exit)
    Arm64ArgTypes.push_back(PtrTy);
  X64ArgTypes.push_back(PtrTy);

  bool HasSretPtr = false;
  getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
                  X64ArgTypes, HasSretPtr);

  getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
                   HasSretPtr);

  Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);

  X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
}

// Append the argument portion of the thunk mangling to Out, and collect the
// corresponding AArch64-side and x64-side argument types. HasSretPtr is set
// by getThunkRetType when the first IR parameter is an sret pointer that has
// already been accounted for.
void AArch64Arm64ECCallLowering::getThunkArgTypes(
    FunctionType *FT, AttributeList AttrList, ThunkType TT, raw_ostream &Out,
    SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes, bool HasSretPtr) {

  Out << "$";
  if (FT->isVarArg()) {
    // We treat the variadic function's thunk as a normal function
    // with the following type on the ARM side:
    //   rettype exitthunk(
    //     ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
    //
    // that can coverage all types of variadic function.
    // x9 is similar to normal exit thunk, store the called function.
    // x0-x3 is the arguments be stored in registers.
    // x4 is the address of the arguments on the stack.
    // x5 is the size of the arguments on the stack.
    //
    // On the x64 side, it's the same except that x5 isn't set.
    //
    // If both the ARM and X64 sides are sret, there are only three
    // arguments in registers.
    //
    // If the X64 side is sret, but the ARM side isn't, we pass an extra value
    // to/from the X64 side, and let SelectionDAG transform it into a memory
    // location.
    Out << "varargs";

    // x0-x3
    for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
      Arm64ArgTypes.push_back(I64Ty);
      X64ArgTypes.push_back(I64Ty);
    }

    // x4
    Arm64ArgTypes.push_back(PtrTy);
    X64ArgTypes.push_back(PtrTy);
    // x5
    Arm64ArgTypes.push_back(I64Ty);
    if (TT != ThunkType::Entry) {
      // FIXME: x5 isn't actually used by the x64 side; revisit once we
      // have proper isel for varargs
      X64ArgTypes.push_back(I64Ty);
    }
    return;
  }

  // Skip the sret pointer if getThunkRetType already consumed parameter 0.
  unsigned I = 0;
  if (HasSretPtr)
    I++;

  if (I == FT->getNumParams()) {
    Out << "v";
    return;
  }

  for (unsigned E = FT->getNumParams(); I != E; ++I) {
#if 0
    // FIXME: Need more information about argument size; see
    // https://reviews.llvm.org/D132926
    uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
    Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
#else
    uint64_t ArgSizeBytes = 0;
    Align ParamAlign = Align();
#endif
    Type *Arm64Ty, *X64Ty;
    canonicalizeThunkType(FT->getParamType(I), ParamAlign,
                          /*Ret*/ false, ArgSizeBytes, Out, Arm64Ty, X64Ty);
    Arm64ArgTypes.push_back(Arm64Ty);
    X64ArgTypes.push_back(X64Ty);
  }
}

// Append the return-type portion of the thunk mangling to Out and compute the
// AArch64-side and x64-side return types. May also append argument types (for
// sret-style returns that are really passed as pointer arguments); in that
// case HasSretPtr is set so getThunkArgTypes skips parameter 0.
void AArch64Arm64ECCallLowering::getThunkRetType(
    FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
    Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
    SmallVectorImpl<Type *> &X64ArgTypes, bool &HasSretPtr) {
  Type *T = FT->getReturnType();
#if 0
  // FIXME: Need more information about argument size; see
  // https://reviews.llvm.org/D132926
  uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
#else
  int64_t ArgSizeBytes = 0;
#endif
  if (T->isVoidTy()) {
    if (FT->getNumParams()) {
      auto SRetAttr = AttrList.getParamAttr(0, Attribute::StructRet);
      auto InRegAttr = AttrList.getParamAttr(0, Attribute::InReg);
      if (SRetAttr.isValid() && InRegAttr.isValid()) {
        // sret+inreg indicates a call that returns a C++ class value. This is
        // actually equivalent to just passing and returning a void* pointer
        // as the first argument. Translate it that way, instead of trying
        // to model "inreg" in the thunk's calling convention, to simplify
        // the rest of the code.
        Out << "i8";
        Arm64RetTy = I64Ty;
        X64RetTy = I64Ty;
        return;
      }
      if (SRetAttr.isValid()) {
        // FIXME: Sanity-check the sret type; if it's an integer or pointer,
        // we'll get screwy mangling/codegen.
        // FIXME: For large struct types, mangle as an integer argument and
        // integer return, so we can reuse more thunks, instead of "m" syntax.
        // (MSVC mangles this case as an integer return with no argument, but
        // that's a miscompile.)
        Type *SRetType = SRetAttr.getValueAsType();
        Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
        Type *Arm64Ty, *X64Ty;
        canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
                              Out, Arm64Ty, X64Ty);
        Arm64RetTy = VoidTy;
        X64RetTy = VoidTy;
        Arm64ArgTypes.push_back(FT->getParamType(0));
        X64ArgTypes.push_back(FT->getParamType(0));
        HasSretPtr = true;
        return;
      }
    }

    Out << "v";
    Arm64RetTy = VoidTy;
    X64RetTy = VoidTy;
    return;
  }

  canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out, Arm64RetTy,
                        X64RetTy);
  if (X64RetTy->isPointerTy()) {
    // If the X64 type is canonicalized to a pointer, that means it's
    // passed/returned indirectly. For a return value, that means it's an
    // sret pointer.
    X64ArgTypes.push_back(X64RetTy);
    X64RetTy = VoidTy;
  }
}

// Canonicalize one value type for thunk mangling: emit its mangling code to
// Out and compute the types used to represent it on the AArch64 side
// (Arm64Ty) and the x64 side (X64Ty). Ret indicates the type is a return
// value (alignment suffixes are only mangled for arguments).
void AArch64Arm64ECCallLowering::canonicalizeThunkType(
    Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes, raw_ostream &Out,
    Type *&Arm64Ty, Type *&X64Ty) {
  if (T->isFloatTy()) {
    Out << "f";
    Arm64Ty = T;
    X64Ty = T;
    return;
  }

  if (T->isDoubleTy()) {
    Out << "d";
    Arm64Ty = T;
    X64Ty = T;
    return;
  }

  if (T->isFloatingPointTy()) {
    report_fatal_error(
        "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
  }

  auto &DL = M->getDataLayout();

  // Look through single-element struct wrappers.
  if (auto *StructTy = dyn_cast<StructType>(T))
    if (StructTy->getNumElements() == 1)
      T = StructTy->getElementType(0);

  if (T->isArrayTy()) {
    Type *ElementTy = T->getArrayElementType();
    uint64_t ElementCnt = T->getArrayNumElements();
    uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;
    uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;
    if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {
      Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;
      if (Alignment.value() >= 16 && !Ret)
        Out << "a" << Alignment.value();
      Arm64Ty = T;
      if (TotalSizeBytes <= 8) {
        // Arm64 returns small structs of float/double in float registers;
        // X64 uses RAX.
        X64Ty = llvm::Type::getIntNTy(M->getContext(), TotalSizeBytes * 8);
      } else {
        // Struct is passed directly on Arm64, but indirectly on X64.
        X64Ty = PtrTy;
      }
      return;
    } else if (T->isFloatingPointTy()) {
      report_fatal_error("Only 32 and 64 bit floating points are supported for "
                         "ARM64EC thunks");
    }
  }

  // Small integers and pointers all mangle as "i8" and are widened to i64.
  if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
    Out << "i8";
    Arm64Ty = I64Ty;
    X64Ty = I64Ty;
    return;
  }

  unsigned TypeSize = ArgSizeBytes;
  if (TypeSize == 0)
    TypeSize = DL.getTypeSizeInBits(T) / 8;
  // "m" (memory) mangling; 4-byte size is the implicit default.
  Out << "m";
  if (TypeSize != 4)
    Out << TypeSize;
  if (Alignment.value() >= 16 && !Ret)
    Out << "a" << Alignment.value();
  // FIXME: Try to canonicalize Arm64Ty more thoroughly?
  Arm64Ty = T;
  if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {
    // Pass directly in an integer register
    X64Ty = llvm::Type::getIntNTy(M->getContext(), TypeSize * 8);
  } else {
    // Passed directly on Arm64, but indirectly on X64.
    X64Ty = PtrTy;
  }
}

// This function builds the "exit thunk", a function which translates
// arguments and return values when calling x64 code from AArch64 code.
Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
                                                     AttributeList Attrs) {
  SmallString<256> ExitThunkName;
  llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  getThunkType(FT, Attrs, ThunkType::Exit, ExitThunkStream, Arm64Ty, X64Ty);
  // Thunks are deduplicated by mangled name; reuse an existing one if present.
  if (Function *F = M->getFunction(ExitThunkName))
    return F;

  Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                 ExitThunkName, M);
  F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
  F->setSection(".wowthk$aa");
  F->setComdat(M->getOrInsertComdat(ExitThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  F->addFnAttr("frame-pointer", "all");
  // Only copy sret from the first argument. For C++ instance methods, clang can
  // stick an sret marking on a later argument, but it doesn't actually affect
  // the ABI, so we can omit it. This avoids triggering a verifier assertion.
  if (FT->getNumParams()) {
    auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
    auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
    if (SRet.isValid() && !InReg.isValid())
      F->addParamAttr(1, SRet);
  }
  // FIXME: Copy anything other than sret? Shouldn't be necessary for normal
  // C ABI, but might show up in other cases.
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
  IRBuilder<> IRB(BB);
  Value *CalleePtr =
      M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
  Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
  auto &DL = M->getDataLayout();
  SmallVector<Value *> Args;

  // Pass the called function in x9.
  Args.push_back(F->arg_begin());

  Type *RetTy = Arm64Ty->getReturnType();
  if (RetTy != X64Ty->getReturnType()) {
    // If the return type is an array or struct, translate it. Values of size
    // 8 or less go into RAX; bigger values go into memory, and we pass a
    // pointer.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      Args.push_back(IRB.CreateAlloca(RetTy));
    }
  }

  for (auto &Arg : make_range(F->arg_begin() + 1, F->arg_end())) {
    // Translate arguments from AArch64 calling convention to x86 calling
    // convention.
    //
    // For simple types, we don't need to do any translation: they're
    // represented the same way. (Implicit sign extension is not part of
    // either convention.)
    //
    // The big thing we have to worry about is struct types... but
    // fortunately AArch64 clang is pretty friendly here: the cases that need
    // translation are always passed as a struct or array. (If we run into
    // some cases where this doesn't work, we can teach clang to mark it up
    // with an attribute.)
    //
    // The first argument is the called function, stored in x9.
    if (Arg.getType()->isArrayTy() || Arg.getType()->isStructTy() ||
        DL.getTypeStoreSize(Arg.getType()) > 8) {
      Value *Mem = IRB.CreateAlloca(Arg.getType());
      IRB.CreateStore(&Arg, Mem);
      if (DL.getTypeStoreSize(Arg.getType()) <= 8) {
        Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
        Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
      } else
        Args.push_back(Mem);
    } else {
      Args.push_back(&Arg);
    }
  }
  // FIXME: Transfer necessary attributes? sret? anything else?

  Callee = IRB.CreateBitCast(Callee, PtrTy);
  CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
  Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);

  Value *RetVal = Call;
  if (RetTy != X64Ty->getReturnType()) {
    // If we rewrote the return type earlier, convert the return value to
    // the proper type.
    if (DL.getTypeStoreSize(RetTy) > 8) {
      // Large return: read it back from the sret buffer we passed.
      RetVal = IRB.CreateLoad(RetTy, Args[1]);
    } else {
      // Small struct returned in RAX: round-trip through a stack slot to
      // reinterpret the integer bits as the AArch64 return type.
      Value *CastAlloca = IRB.CreateAlloca(RetTy);
      IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
      RetVal = IRB.CreateLoad(RetTy, CastAlloca);
    }
  }

  if (RetTy->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);
  return F;
}

// This function builds the "entry thunk", a function which translates
// arguments and return values when calling AArch64 code from x64 code.
Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
  SmallString<256> EntryThunkName;
  llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
  FunctionType *Arm64Ty, *X64Ty;
  getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::Entry,
               EntryThunkStream, Arm64Ty, X64Ty);
  // Thunks are deduplicated by mangled name; reuse an existing one if present.
  if (Function *F = M->getFunction(EntryThunkName))
    return F;

  Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
                                     EntryThunkName, M);
  Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
  Thunk->setSection(".wowthk$aa");
  Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
  // Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
  Thunk->addFnAttr("frame-pointer", "all");

  auto &DL = M->getDataLayout();
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
  IRBuilder<> IRB(BB);

  Type *RetTy = Arm64Ty->getReturnType();
  Type *X64RetType = X64Ty->getReturnType();

  // If the x64 side returns void while the AArch64 side does not, the x64
  // caller passed an sret pointer (thunk arg 1) and the direct return value
  // must be stored through it.
  bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
  unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
  // For varargs, only x0-x3 (plus the callee) are passed through directly;
  // the stack-area pointer is handled specially below.
  unsigned PassthroughArgSize = F->isVarArg() ? 5 : Thunk->arg_size();

  // Translate arguments to call.
  SmallVector<Value *> Args;
  for (unsigned i = ThunkArgOffset, e = PassthroughArgSize; i != e; ++i) {
    Value *Arg = Thunk->getArg(i);
    Type *ArgTy = Arm64Ty->getParamType(i - ThunkArgOffset);
    if (ArgTy->isArrayTy() || ArgTy->isStructTy() ||
        DL.getTypeStoreSize(ArgTy) > 8) {
      // Translate array/struct arguments to the expected type.
      if (DL.getTypeStoreSize(ArgTy) <= 8) {
        // Small aggregate arrives as an integer; reinterpret via a stack slot.
        Value *CastAlloca = IRB.CreateAlloca(ArgTy);
        IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
        Arg = IRB.CreateLoad(ArgTy, CastAlloca);
      } else {
        // Large aggregate arrives by pointer; load the value to pass directly.
        Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
      }
    }
    Args.push_back(Arg);
  }

  if (F->isVarArg()) {
    // The 5th argument to variadic entry thunks is used to model the x64 sp
    // which is passed to the thunk in x4, this can be passed to the callee as
    // the variadic argument start address after skipping over the 32 byte
    // shadow store.

    // The EC thunk CC will assign any argument marked as InReg to x4.
    Thunk->addParamAttr(5, Attribute::InReg);
    Value *Arg = Thunk->getArg(5);
    Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
    Args.push_back(Arg);

    // Pass in a zero variadic argument size (in x5).
    Args.push_back(IRB.getInt64(0));
  }

  // Call the function passed to the thunk.
  Value *Callee = Thunk->getArg(0);
  Callee = IRB.CreateBitCast(Callee, PtrTy);
  CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);

  // Propagate a plain (non-inreg) sret attribute so the verifier and isel see
  // a consistent signature on both the thunk and the inner call.
  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
    Thunk->addParamAttr(1, SRetAttr);
    Call->addParamAttr(0, SRetAttr);
  }

  Value *RetVal = Call;
  if (TransformDirectToSRet) {
    IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
  } else if (X64RetType != RetTy) {
    // Reinterpret the AArch64 return value as the x64 integer return type
    // via a stack slot.
    Value *CastAlloca = IRB.CreateAlloca(X64RetType);
    IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
    RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
  }

  // Return to the caller. Note that the isel has code to translate this
  // "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
  // could emit a tail call here, but that would require a dedicated calling
  // convention, which seems more complicated overall.)
  if (X64RetType->isVoidTy())
    IRB.CreateRetVoid();
  else
    IRB.CreateRet(RetVal);

  return Thunk;
}

// Builds the "guest exit thunk", a helper to call a function which may or may
// not be an exit thunk. (We optimistically assume non-dllimport function
// declarations refer to functions defined in AArch64 code; if the linker
// can't prove that, we use this routine instead.)
Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
  // Only the types are needed here; discard the mangled name.
  llvm::raw_null_ostream NullThunkName;
  FunctionType *Arm64Ty, *X64Ty;
  getThunkType(F->getFunctionType(), F->getAttributes(), ThunkType::GuestExit,
               NullThunkName, Arm64Ty, X64Ty);
  auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
  assert(MangledName && "Can't guest exit to function that's already native");
  std::string ThunkName = *MangledName;
  // For C++-mangled names ("?..."), splice "$exit_thunk" before the first '@'
  // so the decorated name stays well-formed; otherwise just append it.
  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
    ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
  } else {
    ThunkName.append("$exit_thunk");
  }
  Function *GuestExit =
      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
  GuestExit->setSection(".wowthk$aa");
  // Record both name forms as metadata for later emission stages.
  GuestExit->setMetadata(
      "arm64ec_unmangled_name",
      MDNode::get(M->getContext(),
                  MDString::get(M->getContext(), F->getName())));
  GuestExit->setMetadata(
      "arm64ec_ecmangled_name",
      MDNode::get(M->getContext(),
                  MDString::get(M->getContext(), *MangledName)));
  F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
  IRBuilder<> B(BB);

  // Load the global symbol as a pointer to the check function.
  Value *GuardFn;
  if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
    GuardFn = GuardFnCFGlobal;
  else
    GuardFn = GuardFnGlobal;
  LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);

  // Create new call instruction. The CFGuard check should always be a call,
  // even if the original CallBase is an Invoke or CallBr instruction.
  Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
  CallInst *GuardCheck = B.CreateCall(
      GuardFnType, GuardCheckLoad,
      {B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});

  // Ensure that the first argument is passed in the correct register.
  GuardCheck->setCallingConv(CallingConv::CFGuard_Check);

  // The check function returns the (possibly redirected) target; musttail
  // forward all of our arguments to it.
  Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
  SmallVector<Value *> Args;
  for (Argument &Arg : GuestExit->args())
    Args.push_back(&Arg);
  CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);

  if (Call->getType()->isVoidTy())
    B.CreateRetVoid();
  else
    B.CreateRet(Call);

  // Propagate a plain (non-inreg) sret marking, mirroring buildExitThunk.
  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
    GuestExit->addParamAttr(0, SRetAttr);
    Call->addParamAttr(0, SRetAttr);
  }

  return GuestExit;
}

// Lower an indirect call with inline code.
void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
  assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
         "Only applicable for Windows targets");

  IRBuilder<> B(CB);
  Value *CalledOperand = CB->getCalledOperand();

  // If the indirect call is called within catchpad or cleanuppad,
  // we need to copy "funclet" bundle of the call.
  SmallVector<llvm::OperandBundleDef, 1> Bundles;
  if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
    Bundles.push_back(OperandBundleDef(*Bundle));

  // Load the global symbol as a pointer to the check function.
  Value *GuardFn;
  if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
    GuardFn = GuardFnCFGlobal;
  else
    GuardFn = GuardFnGlobal;
  LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);

  // Create new call instruction. The CFGuard check should always be a call,
  // even if the original CallBase is an Invoke or CallBr instruction.
  Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
  CallInst *GuardCheck =
      B.CreateCall(GuardFnType, GuardCheckLoad,
                   {B.CreateBitCast(CalledOperand, B.getPtrTy()),
                    B.CreateBitCast(Thunk, B.getPtrTy())},
                   Bundles);

  // Ensure that the first argument is passed in the correct register.
  GuardCheck->setCallingConv(CallingConv::CFGuard_Check);

  // Redirect the original call through the checked/translated target; the
  // original call instruction itself stays in place.
  Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
  CB->setCalledOperand(GuardRetVal);
}

bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
  if (!GenerateThunks)
    return false;

  M = &Mod;

  // Check if this module has the cfguard flag and read its value.
  if (auto *MD =
          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
    cfguard_module_flag = MD->getZExtValue();

  PtrTy = PointerType::getUnqual(M->getContext());
  I64Ty = Type::getInt64Ty(M->getContext());
  VoidTy = Type::getVoidTy(M->getContext());

  GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
  GuardFnPtrType = PointerType::get(GuardFnType, 0);
  GuardFnCFGlobal =
      M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
  GuardFnGlobal =
      M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);

  // Rewrite calls in every function defined here, except the thunks this pass
  // itself emits (identified by their dedicated calling conventions).
  SetVector<Function *> DirectCalledFns;
  for (Function &F : Mod)
    if (!F.isDeclaration() &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
      processFunction(F, DirectCalledFns);

  // One entry of the "llvm.arm64ec.symbolmap" table: a (Src, Dst, Kind)
  // triple. As emitted below, Kind is 1 for entry thunks, 4 for exit thunks,
  // and 0 for guest-exit-thunk-to-function mappings.
  struct ThunkInfo {
    Constant *Src;
    Constant *Dst;
    unsigned Kind;
  };
  SmallVector<ThunkInfo> ThunkMapping;
  for (Function &F : Mod) {
    // Any defined function whose address can escape needs an entry thunk so
    // x64 code can call it.
    if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
      if (!F.hasComdat())
        F.setComdat(Mod.getOrInsertComdat(F.getName()));
      ThunkMapping.push_back({&F, buildEntryThunk(&F), 1});
    }
  }
  for (Function *F : DirectCalledFns) {
    ThunkMapping.push_back(
        {F, buildExitThunk(F->getFunctionType(), F->getAttributes()), 4});
    if (!F->hasDLLImportStorageClass())
      ThunkMapping.push_back({buildGuestExitThunk(F), F, 0});
  }

  if (!ThunkMapping.empty()) {
    SmallVector<Constant *> ThunkMappingArrayElems;
    for (ThunkInfo &Thunk : ThunkMapping) {
      ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
          {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
           ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
           ConstantInt::get(M->getContext(), APInt(32, Thunk.Kind))}));
    }
    Constant *ThunkMappingArray = ConstantArray::get(
        llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
                             ThunkMappingArrayElems.size()),
        ThunkMappingArrayElems);
    new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
                       GlobalValue::ExternalLinkage, ThunkMappingArray,
                       "llvm.arm64ec.symbolmap");
  }

  return true;
}

// Rename F to its ARM64EC-mangled name and lower its indirect calls.
// Direct calls to potentially-x64 declarations are collected into
// DirectCalledFns for thunk generation in runOnModule. Returns true if any
// indirect calls were lowered.
bool AArch64Arm64ECCallLowering::processFunction(
    Function &F, SetVector<Function *> &DirectCalledFns) {
  SmallVector<CallBase *, 8> IndirectCalls;

  // For ARM64EC targets, a function definition's name is mangled differently
  // from the normal symbol. We currently have no representation of this sort
  // of symbol in IR, so we change the name to the mangled name, then store
  // the unmangled name as metadata. Later passes that need the unmangled
  // name (emitting the definition) can grab it from the metadata.
  //
  // FIXME: Handle functions with weak linkage?
  if (F.hasExternalLinkage() || F.hasWeakLinkage() || F.hasLinkOnceLinkage()) {
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(F.getName().str())) {
      F.setMetadata("arm64ec_unmangled_name",
                    MDNode::get(M->getContext(),
                                MDString::get(M->getContext(), F.getName())));
      // If the function's comdat was keyed on its old name, migrate every
      // member to a comdat keyed on the mangled name.
      if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
        Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
        SmallVector<GlobalObject *> ComdatUsers =
            to_vector(F.getComdat()->getUsers());
        for (GlobalObject *User : ComdatUsers)
          User->setComdat(MangledComdat);
      }
      F.setName(MangledName.value());
    }
  }

  // Iterate over the instructions to find all indirect call/invoke/callbr
  // instructions. Make a separate list of pointers to indirect
  // call/invoke/callbr instructions because the original instructions will be
  // deleted as the checks are added.
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *CB = dyn_cast<CallBase>(&I);
      if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
          CB->isInlineAsm())
        continue;

      // We need to instrument any call that isn't directly calling an
      // ARM64 function.
      //
      // FIXME: getCalledFunction() fails if there's a bitcast (e.g.
      // unprototyped functions in C)
      if (Function *F = CB->getCalledFunction()) {
        // Local, intrinsic, or locally-defined callees are known to be
        // native AArch64 code; no thunk needed.
        if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
            F->isIntrinsic() || !F->isDeclaration())
          continue;

        DirectCalledFns.insert(F);
        continue;
      }

      IndirectCalls.push_back(CB);
      ++Arm64ECCallsLowered;
    }
  }

  if (IndirectCalls.empty())
    return false;

  for (CallBase *CB : IndirectCalls)
    lowerCall(CB);

  return true;
}

char AArch64Arm64ECCallLowering::ID = 0;
INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
                "AArch64Arm64ECCallLowering", false, false)

ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
  return new AArch64Arm64ECCallLowering;
}