1 #include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2 #include "llvm/ExecutionEngine/Orc/Mangling.h"
3
4 using namespace llvm;
5 using namespace orc;
6
tryStartReoptimize()7 bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8 std::unique_lock<std::mutex> Lock(Mutex);
9 if (Reoptimizing)
10 return false;
11
12 Reoptimizing = true;
13 return true;
14 }
15
reoptimizeSucceeded()16 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17 std::unique_lock<std::mutex> Lock(Mutex);
18 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19 Reoptimizing = false;
20 CurVersion++;
21 }
22
reoptimizeFailed()23 void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24 std::unique_lock<std::mutex> Lock(Mutex);
25 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26 Reoptimizing = false;
27 }
28
reigsterRuntimeFunctions(JITDylib & PlatformJD)29 Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) {
30 ExecutionSession::JITDispatchHandlerAssociationMap WFs;
31 using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
32 WFs[Mangle("__orc_rt_reoptimize_tag")] =
33 ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this,
34 &ReOptimizeLayer::rt_reoptimize);
35 return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs));
36 }
37
emit(std::unique_ptr<MaterializationResponsibility> R,ThreadSafeModule TSM)38 void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
39 ThreadSafeModule TSM) {
40 auto &JD = R->getTargetJITDylib();
41
42 bool HasNonCallable = false;
43 for (auto &KV : R->getSymbols()) {
44 auto &Flags = KV.second;
45 if (!Flags.isCallable())
46 HasNonCallable = true;
47 }
48
49 if (HasNonCallable) {
50 BaseLayer.emit(std::move(R), std::move(TSM));
51 return;
52 }
53
54 auto &MUState = createMaterializationUnitState(TSM);
55
56 if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) {
57 registerMaterializationUnitResource(Key, MUState);
58 })) {
59 ES.reportError(std::move(Err));
60 R->failMaterialization();
61 return;
62 }
63
64 if (auto Err =
65 ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
66 ES.reportError(std::move(Err));
67 R->failMaterialization();
68 return;
69 }
70
71 auto InitialDests =
72 emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM));
73 if (!InitialDests) {
74 ES.reportError(InitialDests.takeError());
75 R->failMaterialization();
76 return;
77 }
78
79 RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests));
80 }
81
reoptimizeIfCallFrequent(ReOptimizeLayer & Parent,ReOptMaterializationUnitID MUID,unsigned CurVersion,ThreadSafeModule & TSM)82 Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
83 ReOptMaterializationUnitID MUID,
84 unsigned CurVersion,
85 ThreadSafeModule &TSM) {
86 return TSM.withModuleDo([&](Module &M) -> Error {
87 Type *I64Ty = Type::getInt64Ty(M.getContext());
88 GlobalVariable *Counter = new GlobalVariable(
89 M, I64Ty, false, GlobalValue::InternalLinkage,
90 Constant::getNullValue(I64Ty), "__orc_reopt_counter");
91 auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion);
92 if (auto Err = ArgBufferConst.takeError())
93 return Err;
94 GlobalVariable *ArgBuffer =
95 new GlobalVariable(M, (*ArgBufferConst)->getType(), true,
96 GlobalValue::InternalLinkage, (*ArgBufferConst));
97 for (auto &F : M) {
98 if (F.isDeclaration())
99 continue;
100 auto &BB = F.getEntryBlock();
101 auto *IP = &*BB.getFirstInsertionPt();
102 IRBuilder<> IRB(IP);
103 Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true);
104 Value *Cnt = IRB.CreateLoad(I64Ty, Counter);
105 // Use EQ to prevent further reoptimize calls.
106 Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold);
107 Value *Added = IRB.CreateAdd(Cnt, ConstantInt::get(I64Ty, 1));
108 (void)IRB.CreateStore(Added, Counter);
109 Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cmp, IP, false);
110 createReoptimizeCall(M, *SplitTerminator, ArgBuffer);
111 }
112 return Error::success();
113 });
114 }
115
116 Expected<SymbolMap>
emitMUImplSymbols(ReOptMaterializationUnitState & MUState,uint32_t Version,JITDylib & JD,ThreadSafeModule TSM)117 ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
118 uint32_t Version, JITDylib &JD,
119 ThreadSafeModule TSM) {
120 DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
121 cantFail(TSM.withModuleDo([&](Module &M) -> Error {
122 MangleAndInterner Mangle(ES, M.getDataLayout());
123 for (auto &F : M)
124 if (!F.isDeclaration()) {
125 std::string NewName =
126 (F.getName() + ".__def__." + Twine(Version)).str();
127 RenamedMap[Mangle(F.getName())] = Mangle(NewName);
128 F.setName(NewName);
129 }
130 return Error::success();
131 }));
132
133 auto RT = JD.createResourceTracker();
134 if (auto Err =
135 JD.define(std::make_unique<BasicIRLayerMaterializationUnit>(
136 BaseLayer, *getManglingOptions(), std::move(TSM)),
137 RT))
138 return Err;
139 MUState.setResourceTracker(RT);
140
141 SymbolLookupSet LookupSymbols;
142 for (auto [K, V] : RenamedMap)
143 LookupSymbols.add(V);
144
145 auto ImplSymbols =
146 ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols,
147 LookupKind::Static, SymbolState::Resolved);
148 if (auto Err = ImplSymbols.takeError())
149 return Err;
150
151 SymbolMap Result;
152 for (auto [K, V] : RenamedMap)
153 Result[K] = (*ImplSymbols)[V];
154
155 return Result;
156 }
157
rt_reoptimize(SendErrorFn SendResult,ReOptMaterializationUnitID MUID,uint32_t CurVersion)158 void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
159 ReOptMaterializationUnitID MUID,
160 uint32_t CurVersion) {
161 auto &MUState = getMaterializationUnitState(MUID);
162 if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
163 SendResult(Error::success());
164 return;
165 }
166
167 ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule());
168 auto OldRT = MUState.getResourceTracker();
169 auto &JD = OldRT->getJITDylib();
170
171 if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
172 ES.reportError(std::move(Err));
173 MUState.reoptimizeFailed();
174 SendResult(Error::success());
175 return;
176 }
177
178 auto SymbolDests =
179 emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM));
180 if (!SymbolDests) {
181 ES.reportError(SymbolDests.takeError());
182 MUState.reoptimizeFailed();
183 SendResult(Error::success());
184 return;
185 }
186
187 if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) {
188 ES.reportError(std::move(Err));
189 MUState.reoptimizeFailed();
190 SendResult(Error::success());
191 return;
192 }
193
194 MUState.reoptimizeSucceeded();
195 SendResult(Error::success());
196 }
197
createReoptimizeArgBuffer(Module & M,ReOptMaterializationUnitID MUID,uint32_t CurVersion)198 Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer(
199 Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) {
200 size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion);
201 std::vector<char> ArgBuffer(ArgBufferSize);
202 shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
203 if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion))
204 return make_error<StringError>("Could not serealize args list",
205 inconvertibleErrorCode());
206 return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer));
207 }
208
createReoptimizeCall(Module & M,Instruction & IP,GlobalVariable * ArgBuffer)209 void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
210 GlobalVariable *ArgBuffer) {
211 GlobalVariable *DispatchCtx =
212 M.getGlobalVariable("__orc_rt_jit_dispatch_ctx");
213 if (!DispatchCtx)
214 DispatchCtx = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
215 false, GlobalValue::ExternalLinkage,
216 nullptr, "__orc_rt_jit_dispatch_ctx");
217 GlobalVariable *ReoptimizeTag =
218 M.getGlobalVariable("__orc_rt_reoptimize_tag");
219 if (!ReoptimizeTag)
220 ReoptimizeTag = new GlobalVariable(M, PointerType::get(M.getContext(), 0),
221 false, GlobalValue::ExternalLinkage,
222 nullptr, "__orc_rt_reoptimize_tag");
223 Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch");
224 if (!DispatchFunc) {
225 std::vector<Type *> Args = {PointerType::get(M.getContext(), 0),
226 PointerType::get(M.getContext(), 0),
227 PointerType::get(M.getContext(), 0),
228 IntegerType::get(M.getContext(), 64)};
229 FunctionType *FuncTy =
230 FunctionType::get(Type::getVoidTy(M.getContext()), Args, false);
231 DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage,
232 "__orc_rt_jit_dispatch", &M);
233 }
234 size_t ArgBufferSizeConst =
235 SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{});
236 Constant *ArgBufferSize = ConstantInt::get(
237 IntegerType::get(M.getContext(), 64), ArgBufferSizeConst, false);
238 IRBuilder<> IRB(&IP);
239 (void)IRB.CreateCall(DispatchFunc,
240 {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize});
241 }
242
243 ReOptimizeLayer::ReOptMaterializationUnitState &
createMaterializationUnitState(const ThreadSafeModule & TSM)244 ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
245 std::unique_lock<std::mutex> Lock(Mutex);
246 ReOptMaterializationUnitID MUID = NextID;
247 MUStates.emplace(MUID,
248 ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM)));
249 ++NextID;
250 return MUStates.at(MUID);
251 }
252
253 ReOptimizeLayer::ReOptMaterializationUnitState &
getMaterializationUnitState(ReOptMaterializationUnitID MUID)254 ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
255 std::unique_lock<std::mutex> Lock(Mutex);
256 return MUStates.at(MUID);
257 }
258
registerMaterializationUnitResource(ResourceKey Key,ReOptMaterializationUnitState & State)259 void ReOptimizeLayer::registerMaterializationUnitResource(
260 ResourceKey Key, ReOptMaterializationUnitState &State) {
261 std::unique_lock<std::mutex> Lock(Mutex);
262 MUResources[Key].insert(State.getID());
263 }
264
handleRemoveResources(JITDylib & JD,ResourceKey K)265 Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
266 std::unique_lock<std::mutex> Lock(Mutex);
267 for (auto MUID : MUResources[K])
268 MUStates.erase(MUID);
269
270 MUResources.erase(K);
271 return Error::success();
272 }
273
handleTransferResources(JITDylib & JD,ResourceKey DstK,ResourceKey SrcK)274 void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
275 ResourceKey SrcK) {
276 std::unique_lock<std::mutex> Lock(Mutex);
277 MUResources[DstK].insert_range(MUResources[SrcK]);
278 MUResources.erase(SrcK);
279 }
280