diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2016-07-23 20:41:05 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2016-07-23 20:41:05 +0000 |
commit | 01095a5d43bbfde13731688ddcf6048ebb8b7721 (patch) | |
tree | 4def12e759965de927d963ac65840d663ef9d1ea /include/llvm/ExecutionEngine/Orc | |
parent | f0f4822ed4b66e3579e92a89f368f8fb860e218e (diff) | |
download | src-01095a5d43bbfde13731688ddcf6048ebb8b7721.tar.gz src-01095a5d43bbfde13731688ddcf6048ebb8b7721.zip |
Vendor import of llvm release_39 branch r276489:vendor/llvm/llvm-release_39-r276489
Notes
Notes:
svn path=/vendor/llvm/dist/; revision=303231
svn path=/vendor/llvm/llvm-release_39-r276489/; revision=303232; tag=vendor/llvm/llvm-release_39-r276489
Diffstat (limited to 'include/llvm/ExecutionEngine/Orc')
17 files changed, 1736 insertions, 980 deletions
diff --git a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h index 84af4728b350..ef88dd03ad4f 100644 --- a/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h +++ b/include/llvm/ExecutionEngine/Orc/CompileOnDemandLayer.h @@ -19,12 +19,12 @@ #include "LambdaResolver.h" #include "LogicalDylib.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" #include "llvm/Transforms/Utils/Cloning.h" #include <list> #include <memory> #include <set> - -#include "llvm/Support/Debug.h" +#include <utility> namespace llvm { namespace orc { @@ -46,7 +46,7 @@ private: class LambdaMaterializer final : public ValueMaterializer { public: LambdaMaterializer(MaterializerFtor M) : M(std::move(M)) {} - Value *materializeDeclFor(Value *V) final { return M(V); } + Value *materialize(Value *V) final { return M(V); } private: MaterializerFtor M; @@ -145,7 +145,7 @@ private: return *this; } - SymbolResolverFtor ExternalSymbolResolver; + std::unique_ptr<RuntimeDyld::SymbolResolver> ExternalSymbolResolver; std::unique_ptr<ResourceOwner<RuntimeDyld::MemoryManager>> MemMgr; ModuleAdderFtor ModuleAdder; }; @@ -173,7 +173,7 @@ public: CompileCallbackMgrT &CallbackMgr, IndirectStubsManagerBuilderT CreateIndirectStubsManager, bool CloneStubsIntoPartitions = true) - : BaseLayer(BaseLayer), Partition(Partition), + : BaseLayer(BaseLayer), Partition(std::move(Partition)), CompileCallbackMgr(CallbackMgr), CreateIndirectStubsManager(std::move(CreateIndirectStubsManager)), CloneStubsIntoPartitions(CloneStubsIntoPartitions) {} @@ -188,10 +188,7 @@ public: LogicalDylibs.push_back(CODLogicalDylib(BaseLayer)); auto &LDResources = LogicalDylibs.back().getDylibResources(); - LDResources.ExternalSymbolResolver = - [Resolver](const std::string &Name) { - return Resolver->findSymbol(Name); - }; + LDResources.ExternalSymbolResolver = std::move(Resolver); auto &MemMgrRef = *MemMgr; LDResources.MemMgr = @@ -256,14 +253,8 @@ private: Module 
&SrcM = LMResources.SourceModule->getResource(); - // Create the GlobalValues module. + // Create stub functions. const DataLayout &DL = SrcM.getDataLayout(); - auto GVsM = llvm::make_unique<Module>((SrcM.getName() + ".globals").str(), - SrcM.getContext()); - GVsM->setDataLayout(DL); - - // Create function stubs. - ValueToValueMapTy VMap; { typename IndirectStubsMgrT::StubInitsMap StubInits; for (auto &F : SrcM) { @@ -295,6 +286,19 @@ private: assert(!EC && "Error generating stubs"); } + // If this module doesn't contain any globals or aliases we can bail out + // early and avoid the overhead of creating and managing an empty globals + // module. + if (SrcM.global_empty() && SrcM.alias_empty()) + return; + + // Create the GlobalValues module. + auto GVsM = llvm::make_unique<Module>((SrcM.getName() + ".globals").str(), + SrcM.getContext()); + GVsM->setDataLayout(DL); + + ValueToValueMapTy VMap; + // Clone global variable decls. for (auto &GV : SrcM.globals()) if (!GV.isDeclaration() && !VMap.count(&GV)) @@ -356,16 +360,17 @@ private: [&LD, LMH](const std::string &Name) { auto &LMResources = LD.getLogicalModuleResources(LMH); if (auto Sym = LMResources.StubsMgr->findStub(Name, false)) - return RuntimeDyld::SymbolInfo(Sym.getAddress(), Sym.getFlags()); - return LD.getDylibResources().ExternalSymbolResolver(Name); + return Sym.toRuntimeDyldSymbol(); + auto &LDResolver = LD.getDylibResources().ExternalSymbolResolver; + return LDResolver->findSymbolInLogicalDylib(Name); }, - [](const std::string &Name) { - return RuntimeDyld::SymbolInfo(nullptr); + [&LD](const std::string &Name) { + auto &LDResolver = LD.getDylibResources().ExternalSymbolResolver; + return LDResolver->findSymbol(Name); }); - auto GVsH = - LD.getDylibResources().ModuleAdder(BaseLayer, std::move(GVsM), - std::move(GVsResolver)); + auto GVsH = LD.getDylibResources().ModuleAdder(BaseLayer, std::move(GVsM), + std::move(GVsResolver)); LD.addToLogicalModule(LMH, GVsH); } @@ -481,20 +486,18 @@ private: // Create 
memory manager and symbol resolver. auto Resolver = createLambdaResolver( [this, &LD, LMH](const std::string &Name) { - if (auto Symbol = LD.findSymbolInternally(LMH, Name)) - return RuntimeDyld::SymbolInfo(Symbol.getAddress(), - Symbol.getFlags()); - return LD.getDylibResources().ExternalSymbolResolver(Name); + if (auto Sym = LD.findSymbolInternally(LMH, Name)) + return Sym.toRuntimeDyldSymbol(); + auto &LDResolver = LD.getDylibResources().ExternalSymbolResolver; + return LDResolver->findSymbolInLogicalDylib(Name); }, - [this, &LD, LMH](const std::string &Name) { - if (auto Symbol = LD.findSymbolInternally(LMH, Name)) - return RuntimeDyld::SymbolInfo(Symbol.getAddress(), - Symbol.getFlags()); - return RuntimeDyld::SymbolInfo(nullptr); + [this, &LD](const std::string &Name) { + auto &LDResolver = LD.getDylibResources().ExternalSymbolResolver; + return LDResolver->findSymbol(Name); }); return LD.getDylibResources().ModuleAdder(BaseLayer, std::move(M), - std::move(Resolver)); + std::move(Resolver)); } BaseLayerT &BaseLayer; diff --git a/include/llvm/ExecutionEngine/Orc/CompileUtils.h b/include/llvm/ExecutionEngine/Orc/CompileUtils.h index 1e7d211196f5..ce0864fbd9c9 100644 --- a/include/llvm/ExecutionEngine/Orc/CompileUtils.h +++ b/include/llvm/ExecutionEngine/Orc/CompileUtils.h @@ -42,12 +42,13 @@ public: PM.run(M); std::unique_ptr<MemoryBuffer> ObjBuffer( new ObjectMemoryBuffer(std::move(ObjBufferSV))); - ErrorOr<std::unique_ptr<object::ObjectFile>> Obj = + Expected<std::unique_ptr<object::ObjectFile>> Obj = object::ObjectFile::createObjectFile(ObjBuffer->getMemBufferRef()); - // TODO: Actually report errors helpfully. typedef object::OwningBinary<object::ObjectFile> OwningObj; if (Obj) return OwningObj(std::move(*Obj), std::move(ObjBuffer)); + // TODO: Actually report errors helpfully. 
+ consumeError(Obj.takeError()); return OwningObj(nullptr, nullptr); } diff --git a/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h b/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h index e4bed95fdabf..e6ce18a42b8b 100644 --- a/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h +++ b/include/llvm/ExecutionEngine/Orc/IRCompileLayer.h @@ -37,9 +37,6 @@ public: private: typedef typename BaseLayerT::ObjSetHandleT ObjSetHandleT; - typedef std::vector<std::unique_ptr<object::ObjectFile>> OwningObjectVec; - typedef std::vector<std::unique_ptr<MemoryBuffer>> OwningBufferVec; - public: /// @brief Handle to a set of compiled modules. typedef ObjSetHandleT ModuleSetHandleT; @@ -62,28 +59,29 @@ public: ModuleSetHandleT addModuleSet(ModuleSetT Ms, MemoryManagerPtrT MemMgr, SymbolResolverPtrT Resolver) { - OwningObjectVec Objects; - OwningBufferVec Buffers; + std::vector<std::unique_ptr<object::OwningBinary<object::ObjectFile>>> + Objects; for (const auto &M : Ms) { - std::unique_ptr<object::ObjectFile> Object; - std::unique_ptr<MemoryBuffer> Buffer; + auto Object = + llvm::make_unique<object::OwningBinary<object::ObjectFile>>(); if (ObjCache) - std::tie(Object, Buffer) = tryToLoadFromObjectCache(*M).takeBinary(); + *Object = tryToLoadFromObjectCache(*M); - if (!Object) { - std::tie(Object, Buffer) = Compile(*M).takeBinary(); + if (!Object->getBinary()) { + *Object = Compile(*M); if (ObjCache) - ObjCache->notifyObjectCompiled(&*M, Buffer->getMemBufferRef()); + ObjCache->notifyObjectCompiled(&*M, + Object->getBinary()->getMemoryBufferRef()); } Objects.push_back(std::move(Object)); - Buffers.push_back(std::move(Buffer)); } ModuleSetHandleT H = - BaseLayer.addObjectSet(Objects, std::move(MemMgr), std::move(Resolver)); + BaseLayer.addObjectSet(std::move(Objects), std::move(MemMgr), + std::move(Resolver)); return H; } @@ -126,10 +124,13 @@ private: if (!ObjBuffer) return object::OwningBinary<object::ObjectFile>(); - ErrorOr<std::unique_ptr<object::ObjectFile>> Obj = + 
Expected<std::unique_ptr<object::ObjectFile>> Obj = object::ObjectFile::createObjectFile(ObjBuffer->getMemBufferRef()); - if (!Obj) + if (!Obj) { + // TODO: Actually report errors helpfully. + consumeError(Obj.takeError()); return object::OwningBinary<object::ObjectFile>(); + } return object::OwningBinary<object::ObjectFile>(std::move(*Obj), std::move(ObjBuffer)); diff --git a/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h b/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h index e17630fa05ff..51172c51e136 100644 --- a/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h +++ b/include/llvm/ExecutionEngine/Orc/IndirectionUtils.h @@ -16,14 +16,12 @@ #include "JITSymbol.h" #include "LambdaResolver.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Mangler.h" #include "llvm/IR/Module.h" -#include "llvm/Transforms/Utils/ValueMapper.h" #include "llvm/Support/Process.h" -#include <sstream> +#include "llvm/Transforms/Utils/ValueMapper.h" namespace llvm { namespace orc { @@ -31,7 +29,6 @@ namespace orc { /// @brief Target-independent base class for compile callback management. class JITCompileCallbackManager { public: - typedef std::function<TargetAddress()> CompileFtor; /// @brief Handle to a newly created compile callback. Can be used to get an @@ -40,12 +37,13 @@ public: class CompileCallbackInfo { public: CompileCallbackInfo(TargetAddress Addr, CompileFtor &Compile) - : Addr(Addr), Compile(Compile) {} + : Addr(Addr), Compile(Compile) {} TargetAddress getAddress() const { return Addr; } void setCompileAction(CompileFtor Compile) { this->Compile = std::move(Compile); } + private: TargetAddress Addr; CompileFtor &Compile; @@ -55,7 +53,7 @@ public: /// @param ErrorHandlerAddress The address of an error handler in the target /// process to be used if a compile callback fails. 
JITCompileCallbackManager(TargetAddress ErrorHandlerAddress) - : ErrorHandlerAddress(ErrorHandlerAddress) {} + : ErrorHandlerAddress(ErrorHandlerAddress) {} virtual ~JITCompileCallbackManager() {} @@ -71,8 +69,10 @@ public: // Found a callback handler. Yank this trampoline out of the active list and // put it back in the available trampolines list, then try to run the // handler's compile and update actions. - // Moving the trampoline ID back to the available list first means there's at - // least one available trampoline if the compile action triggers a request for + // Moving the trampoline ID back to the available list first means there's + // at + // least one available trampoline if the compile action triggers a request + // for // a new one. auto Compile = std::move(I->second); ActiveTrampolines.erase(I); @@ -118,7 +118,6 @@ protected: std::vector<TargetAddress> AvailableTrampolines; private: - TargetAddress getAvailableTrampolineAddr() { if (this->AvailableTrampolines.empty()) grow(); @@ -139,20 +138,17 @@ private: template <typename TargetT> class LocalJITCompileCallbackManager : public JITCompileCallbackManager { public: - /// @brief Construct a InProcessJITCompileCallbackManager. /// @param ErrorHandlerAddress The address of an error handler in the target /// process to be used if a compile callback fails. LocalJITCompileCallbackManager(TargetAddress ErrorHandlerAddress) - : JITCompileCallbackManager(ErrorHandlerAddress) { + : JITCompileCallbackManager(ErrorHandlerAddress) { /// Set up the resolver block. 
std::error_code EC; - ResolverBlock = - sys::OwningMemoryBlock( - sys::Memory::allocateMappedMemory(TargetT::ResolverCodeSize, nullptr, - sys::Memory::MF_READ | - sys::Memory::MF_WRITE, EC)); + ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + TargetT::ResolverCodeSize, nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); assert(!EC && "Failed to allocate resolver block"); TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()), @@ -165,13 +161,11 @@ public: } private: - static TargetAddress reenter(void *CCMgr, void *TrampolineId) { JITCompileCallbackManager *Mgr = - static_cast<JITCompileCallbackManager*>(CCMgr); + static_cast<JITCompileCallbackManager *>(CCMgr); return Mgr->executeCompileCallback( - static_cast<TargetAddress>( - reinterpret_cast<uintptr_t>(TrampolineId))); + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineId))); } void grow() override { @@ -179,18 +173,16 @@ private: std::error_code EC; auto TrampolineBlock = - sys::OwningMemoryBlock( - sys::Memory::allocateMappedMemory(sys::Process::getPageSize(), nullptr, - sys::Memory::MF_READ | - sys::Memory::MF_WRITE, EC)); + sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( + sys::Process::getPageSize(), nullptr, + sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); assert(!EC && "Failed to allocate trampoline block"); - unsigned NumTrampolines = - (sys::Process::getPageSize() - TargetT::PointerSize) / + (sys::Process::getPageSize() - TargetT::PointerSize) / TargetT::TrampolineSize; - uint8_t *TrampolineMem = static_cast<uint8_t*>(TrampolineBlock.base()); + uint8_t *TrampolineMem = static_cast<uint8_t *>(TrampolineBlock.base()); TargetT::writeTrampolines(TrampolineMem, ResolverBlock.base(), NumTrampolines); @@ -214,19 +206,18 @@ private: /// @brief Base class for managing collections of named indirect stubs. class IndirectStubsManager { public: - /// @brief Map type for initializing the manager. See init. 
typedef StringMap<std::pair<TargetAddress, JITSymbolFlags>> StubInitsMap; virtual ~IndirectStubsManager() {} /// @brief Create a single stub with the given name, target address and flags. - virtual std::error_code createStub(StringRef StubName, TargetAddress StubAddr, - JITSymbolFlags StubFlags) = 0; + virtual Error createStub(StringRef StubName, TargetAddress StubAddr, + JITSymbolFlags StubFlags) = 0; /// @brief Create StubInits.size() stubs with the given names, target /// addresses, and flags. - virtual std::error_code createStubs(const StubInitsMap &StubInits) = 0; + virtual Error createStubs(const StubInitsMap &StubInits) = 0; /// @brief Find the stub with the given name. If ExportedStubsOnly is true, /// this will only return a result if the stub's flags indicate that it @@ -237,7 +228,8 @@ public: virtual JITSymbol findPointer(StringRef Name) = 0; /// @brief Change the value of the implementation pointer for the stub. - virtual std::error_code updatePointer(StringRef Name, TargetAddress NewAddr) = 0; + virtual Error updatePointer(StringRef Name, TargetAddress NewAddr) = 0; + private: virtual void anchor(); }; @@ -247,26 +239,25 @@ private: template <typename TargetT> class LocalIndirectStubsManager : public IndirectStubsManager { public: - - std::error_code createStub(StringRef StubName, TargetAddress StubAddr, - JITSymbolFlags StubFlags) override { - if (auto EC = reserveStubs(1)) - return EC; + Error createStub(StringRef StubName, TargetAddress StubAddr, + JITSymbolFlags StubFlags) override { + if (auto Err = reserveStubs(1)) + return Err; createStubInternal(StubName, StubAddr, StubFlags); - return std::error_code(); + return Error::success(); } - std::error_code createStubs(const StubInitsMap &StubInits) override { - if (auto EC = reserveStubs(StubInits.size())) - return EC; + Error createStubs(const StubInitsMap &StubInits) override { + if (auto Err = reserveStubs(StubInits.size())) + return Err; for (auto &Entry : StubInits) 
createStubInternal(Entry.first(), Entry.second.first, Entry.second.second); - return std::error_code(); + return Error::success(); } JITSymbol findStub(StringRef Name, bool ExportedStubsOnly) override { @@ -277,7 +268,7 @@ public: void *StubAddr = IndirectStubsInfos[Key.first].getStub(Key.second); assert(StubAddr && "Missing stub address"); auto StubTargetAddr = - static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(StubAddr)); + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(StubAddr)); auto StubSymbol = JITSymbol(StubTargetAddr, I->second.second); if (ExportedStubsOnly && !StubSymbol.isExported()) return nullptr; @@ -292,35 +283,34 @@ public: void *PtrAddr = IndirectStubsInfos[Key.first].getPtr(Key.second); assert(PtrAddr && "Missing pointer address"); auto PtrTargetAddr = - static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(PtrAddr)); + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(PtrAddr)); return JITSymbol(PtrTargetAddr, I->second.second); } - std::error_code updatePointer(StringRef Name, TargetAddress NewAddr) override { + Error updatePointer(StringRef Name, TargetAddress NewAddr) override { auto I = StubIndexes.find(Name); assert(I != StubIndexes.end() && "No stub pointer for symbol"); auto Key = I->second.first; *IndirectStubsInfos[Key.first].getPtr(Key.second) = - reinterpret_cast<void*>(static_cast<uintptr_t>(NewAddr)); - return std::error_code(); + reinterpret_cast<void *>(static_cast<uintptr_t>(NewAddr)); + return Error::success(); } private: - - std::error_code reserveStubs(unsigned NumStubs) { + Error reserveStubs(unsigned NumStubs) { if (NumStubs <= FreeStubs.size()) - return std::error_code(); + return Error::success(); unsigned NewStubsRequired = NumStubs - FreeStubs.size(); unsigned NewBlockId = IndirectStubsInfos.size(); typename TargetT::IndirectStubsInfo ISI; - if (auto EC = TargetT::emitIndirectStubsBlock(ISI, NewStubsRequired, - nullptr)) - return EC; + if (auto Err = + TargetT::emitIndirectStubsBlock(ISI, 
NewStubsRequired, nullptr)) + return Err; for (unsigned I = 0; I < ISI.getNumStubs(); ++I) FreeStubs.push_back(std::make_pair(NewBlockId, I)); IndirectStubsInfos.push_back(std::move(ISI)); - return std::error_code(); + return Error::success(); } void createStubInternal(StringRef StubName, TargetAddress InitAddr, @@ -328,7 +318,7 @@ private: auto Key = FreeStubs.back(); FreeStubs.pop_back(); *IndirectStubsInfos[Key.first].getPtr(Key.second) = - reinterpret_cast<void*>(static_cast<uintptr_t>(InitAddr)); + reinterpret_cast<void *>(static_cast<uintptr_t>(InitAddr)); StubIndexes[StubName] = std::make_pair(Key, StubFlags); } @@ -338,17 +328,32 @@ private: StringMap<std::pair<StubKey, JITSymbolFlags>> StubIndexes; }; +/// @brief Create a local compile callback manager. +/// +/// The given target triple will determine the ABI, and the given +/// ErrorHandlerAddress will be used by the resulting compile callback +/// manager if a compile callback fails. +std::unique_ptr<JITCompileCallbackManager> +createLocalCompileCallbackManager(const Triple &T, + TargetAddress ErrorHandlerAddress); + +/// @brief Create a local indriect stubs manager builder. +/// +/// The given target triple will determine the ABI. +std::function<std::unique_ptr<IndirectStubsManager>()> +createLocalIndirectStubsManagerBuilder(const Triple &T); + /// @brief Build a function pointer of FunctionType with the given constant /// address. /// /// Usage example: Turn a trampoline address into a function pointer constant /// for use in a stub. -Constant* createIRTypedAddress(FunctionType &FT, TargetAddress Addr); +Constant *createIRTypedAddress(FunctionType &FT, TargetAddress Addr); /// @brief Create a function pointer with the given type, name, and initializer /// in the given Module. 
-GlobalVariable* createImplPointer(PointerType &PT, Module &M, - const Twine &Name, Constant *Initializer); +GlobalVariable *createImplPointer(PointerType &PT, Module &M, const Twine &Name, + Constant *Initializer); /// @brief Turn a function declaration into a stub function that makes an /// indirect call using the given function pointer. @@ -373,7 +378,7 @@ void makeAllSymbolsExternallyAccessible(Module &M); /// modules with these utilities, all decls should be cloned (and added to a /// single VMap) before any bodies are moved. This will ensure that references /// between functions all refer to the versions in the new module. -Function* cloneFunctionDecl(Module &Dst, const Function &F, +Function *cloneFunctionDecl(Module &Dst, const Function &F, ValueToValueMapTy *VMap = nullptr); /// @brief Move the body of function 'F' to a cloned function declaration in a @@ -389,7 +394,7 @@ void moveFunctionBody(Function &OrigF, ValueToValueMapTy &VMap, Function *NewF = nullptr); /// @brief Clone a global variable declaration into a new module. -GlobalVariable* cloneGlobalVariableDecl(Module &Dst, const GlobalVariable &GV, +GlobalVariable *cloneGlobalVariableDecl(Module &Dst, const GlobalVariable &GV, ValueToValueMapTy *VMap = nullptr); /// @brief Move global variable GV from its parent module to cloned global @@ -406,7 +411,7 @@ void moveGlobalVariableInitializer(GlobalVariable &OrigGV, GlobalVariable *NewGV = nullptr); /// @brief Clone -GlobalAlias* cloneGlobalAliasDecl(Module &Dst, const GlobalAlias &OrigA, +GlobalAlias *cloneGlobalAliasDecl(Module &Dst, const GlobalAlias &OrigA, ValueToValueMapTy &VMap); } // End namespace orc. 
diff --git a/include/llvm/ExecutionEngine/Orc/JITSymbol.h b/include/llvm/ExecutionEngine/Orc/JITSymbol.h index 422a3761837c..464417e4e6d5 100644 --- a/include/llvm/ExecutionEngine/Orc/JITSymbol.h +++ b/include/llvm/ExecutionEngine/Orc/JITSymbol.h @@ -15,6 +15,7 @@ #define LLVM_EXECUTIONENGINE_ORC_JITSYMBOL_H #include "llvm/ExecutionEngine/JITSymbolFlags.h" +#include "llvm/ExecutionEngine/RuntimeDyld.h" #include "llvm/Support/DataTypes.h" #include <cassert> #include <functional> @@ -52,6 +53,10 @@ public: JITSymbol(GetAddressFtor GetAddress, JITSymbolFlags Flags) : JITSymbolBase(Flags), GetAddress(std::move(GetAddress)), CachedAddr(0) {} + /// @brief Create a JITSymbol from a RuntimeDyld::SymbolInfo. + JITSymbol(const RuntimeDyld::SymbolInfo &Sym) + : JITSymbolBase(Sym.getFlags()), CachedAddr(Sym.getAddress()) {} + /// @brief Returns true if the symbol exists, false otherwise. explicit operator bool() const { return CachedAddr || GetAddress; } @@ -66,6 +71,11 @@ public: return CachedAddr; } + /// @brief Convert this JITSymbol to a RuntimeDyld::SymbolInfo. 
+ RuntimeDyld::SymbolInfo toRuntimeDyldSymbol() { + return RuntimeDyld::SymbolInfo(getAddress(), getFlags()); + } + private: GetAddressFtor GetAddress; TargetAddress CachedAddr; diff --git a/include/llvm/ExecutionEngine/Orc/LambdaResolver.h b/include/llvm/ExecutionEngine/Orc/LambdaResolver.h index faa23658524f..a42b9d5c29d1 100644 --- a/include/llvm/ExecutionEngine/Orc/LambdaResolver.h +++ b/include/llvm/ExecutionEngine/Orc/LambdaResolver.h @@ -18,42 +18,41 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/RuntimeDyld.h" #include <memory> -#include <vector> namespace llvm { namespace orc { -template <typename ExternalLookupFtorT, typename DylibLookupFtorT> +template <typename DylibLookupFtorT, typename ExternalLookupFtorT> class LambdaResolver : public RuntimeDyld::SymbolResolver { public: - LambdaResolver(ExternalLookupFtorT ExternalLookupFtor, - DylibLookupFtorT DylibLookupFtor) - : ExternalLookupFtor(ExternalLookupFtor), - DylibLookupFtor(DylibLookupFtor) {} - - RuntimeDyld::SymbolInfo findSymbol(const std::string &Name) final { - return ExternalLookupFtor(Name); - } + LambdaResolver(DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor) + : DylibLookupFtor(DylibLookupFtor), + ExternalLookupFtor(ExternalLookupFtor) {} RuntimeDyld::SymbolInfo findSymbolInLogicalDylib(const std::string &Name) final { return DylibLookupFtor(Name); } + RuntimeDyld::SymbolInfo findSymbol(const std::string &Name) final { + return ExternalLookupFtor(Name); + } + private: - ExternalLookupFtorT ExternalLookupFtor; DylibLookupFtorT DylibLookupFtor; + ExternalLookupFtorT ExternalLookupFtor; }; -template <typename ExternalLookupFtorT, - typename DylibLookupFtorT> -std::unique_ptr<LambdaResolver<ExternalLookupFtorT, DylibLookupFtorT>> -createLambdaResolver(ExternalLookupFtorT ExternalLookupFtor, - DylibLookupFtorT DylibLookupFtor) { - typedef LambdaResolver<ExternalLookupFtorT, DylibLookupFtorT> LR; - return make_unique<LR>(std::move(ExternalLookupFtor), 
- std::move(DylibLookupFtor)); +template <typename DylibLookupFtorT, + typename ExternalLookupFtorT> +std::unique_ptr<LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT>> +createLambdaResolver(DylibLookupFtorT DylibLookupFtor, + ExternalLookupFtorT ExternalLookupFtor) { + typedef LambdaResolver<DylibLookupFtorT, ExternalLookupFtorT> LR; + return make_unique<LR>(std::move(DylibLookupFtor), + std::move(ExternalLookupFtor)); } } // End namespace orc. diff --git a/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h b/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h index a5286ff9adde..c5fb6b847b30 100644 --- a/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h +++ b/include/llvm/ExecutionEngine/Orc/LazyEmittingLayer.h @@ -195,13 +195,8 @@ private: for (const auto &M : Ms) { Mangler Mang; - for (const auto &V : M->globals()) - if (auto GV = addGlobalValue(*Symbols, V, Mang, SearchName, - ExportedSymbolsOnly)) - return GV; - - for (const auto &F : *M) - if (auto GV = addGlobalValue(*Symbols, F, Mang, SearchName, + for (const auto &GO : M->global_objects()) + if (auto GV = addGlobalValue(*Symbols, GO, Mang, SearchName, ExportedSymbolsOnly)) return GV; } diff --git a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h b/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h index 4dc48f114883..a7798d8beb8d 100644 --- a/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h +++ b/include/llvm/ExecutionEngine/Orc/ObjectLinkingLayer.h @@ -26,7 +26,6 @@ namespace orc { class ObjectLinkingLayerBase { protected: - /// @brief Holds a set of objects to be allocated/linked as a unit in the JIT. 
/// /// An instance of this class will be created for each set of objects added @@ -38,38 +37,31 @@ protected: LinkedObjectSet(const LinkedObjectSet&) = delete; void operator=(const LinkedObjectSet&) = delete; public: - LinkedObjectSet(RuntimeDyld::MemoryManager &MemMgr, - RuntimeDyld::SymbolResolver &Resolver, - bool ProcessAllSections) - : RTDyld(llvm::make_unique<RuntimeDyld>(MemMgr, Resolver)), - State(Raw) { - RTDyld->setProcessAllSections(ProcessAllSections); - } - + LinkedObjectSet() = default; virtual ~LinkedObjectSet() {} - std::unique_ptr<RuntimeDyld::LoadedObjectInfo> - addObject(const object::ObjectFile &Obj) { - return RTDyld->loadObject(Obj); - } - - RuntimeDyld::SymbolInfo getSymbol(StringRef Name) const { - return RTDyld->getSymbol(Name); - } + virtual void finalize() = 0; - bool NeedsFinalization() const { return (State == Raw); } + virtual JITSymbol::GetAddressFtor + getSymbolMaterializer(std::string Name) = 0; - virtual void Finalize() = 0; + virtual void mapSectionAddress(const void *LocalAddress, + TargetAddress TargetAddr) const = 0; - void mapSectionAddress(const void *LocalAddress, TargetAddress TargetAddr) { - assert((State != Finalized) && - "Attempting to remap sections for finalized objects."); - RTDyld->mapSectionAddress(LocalAddress, TargetAddr); + JITSymbol getSymbol(StringRef Name, bool ExportedSymbolsOnly) { + auto SymEntry = SymbolTable.find(Name); + if (SymEntry == SymbolTable.end()) + return nullptr; + if (!SymEntry->second.isExported() && ExportedSymbolsOnly) + return nullptr; + if (!Finalized) + return JITSymbol(getSymbolMaterializer(Name), + SymEntry->second.getFlags()); + return JITSymbol(SymEntry->second); } - protected: - std::unique_ptr<RuntimeDyld> RTDyld; - enum { Raw, Finalizing, Finalized } State; + StringMap<RuntimeDyld::SymbolInfo> SymbolTable; + bool Finalized = false; }; typedef std::list<std::unique_ptr<LinkedObjectSet>> LinkedObjectSetListT; @@ -79,6 +71,7 @@ public: typedef LinkedObjectSetListT::iterator 
ObjSetHandleT; }; + /// @brief Default (no-op) action to perform when loading objects. class DoNothingOnNotifyLoaded { public: @@ -95,34 +88,126 @@ public: /// symbols. template <typename NotifyLoadedFtor = DoNothingOnNotifyLoaded> class ObjectLinkingLayer : public ObjectLinkingLayerBase { +public: + + /// @brief Functor for receiving finalization notifications. + typedef std::function<void(ObjSetHandleT)> NotifyFinalizedFtor; + private: - template <typename MemoryManagerPtrT, typename SymbolResolverPtrT> + template <typename ObjSetT, typename MemoryManagerPtrT, + typename SymbolResolverPtrT, typename FinalizerFtor> class ConcreteLinkedObjectSet : public LinkedObjectSet { public: - ConcreteLinkedObjectSet(MemoryManagerPtrT MemMgr, + ConcreteLinkedObjectSet(ObjSetT Objects, MemoryManagerPtrT MemMgr, SymbolResolverPtrT Resolver, + FinalizerFtor Finalizer, bool ProcessAllSections) - : LinkedObjectSet(*MemMgr, *Resolver, ProcessAllSections), - MemMgr(std::move(MemMgr)), Resolver(std::move(Resolver)) { } + : MemMgr(std::move(MemMgr)), + PFC(llvm::make_unique<PreFinalizeContents>(std::move(Objects), + std::move(Resolver), + std::move(Finalizer), + ProcessAllSections)) { + buildInitialSymbolTable(PFC->Objects); + } + + void setHandle(ObjSetHandleT H) { + PFC->Handle = H; + } + + void finalize() override { + assert(PFC && "mapSectionAddress called on finalized LinkedObjectSet"); + + RuntimeDyld RTDyld(*MemMgr, *PFC->Resolver); + RTDyld.setProcessAllSections(PFC->ProcessAllSections); + PFC->RTDyld = &RTDyld; - void Finalize() override { - State = Finalizing; - RTDyld->finalizeWithMemoryManagerLocking(); - State = Finalized; + PFC->Finalizer(PFC->Handle, RTDyld, std::move(PFC->Objects), + [&]() { + this->updateSymbolTable(RTDyld); + this->Finalized = true; + }); + + // Release resources. 
+ PFC = nullptr; + } + + JITSymbol::GetAddressFtor getSymbolMaterializer(std::string Name) override { + return + [this, Name]() { + // The symbol may be materialized between the creation of this lambda + // and its execution, so we need to double check. + if (!this->Finalized) + this->finalize(); + return this->getSymbol(Name, false).getAddress(); + }; + } + + void mapSectionAddress(const void *LocalAddress, + TargetAddress TargetAddr) const override { + assert(PFC && "mapSectionAddress called on finalized LinkedObjectSet"); + assert(PFC->RTDyld && "mapSectionAddress called on raw LinkedObjectSet"); + PFC->RTDyld->mapSectionAddress(LocalAddress, TargetAddr); } private: + + void buildInitialSymbolTable(const ObjSetT &Objects) { + for (const auto &Obj : Objects) + for (auto &Symbol : getObject(*Obj).symbols()) { + if (Symbol.getFlags() & object::SymbolRef::SF_Undefined) + continue; + Expected<StringRef> SymbolName = Symbol.getName(); + // FIXME: Raise an error for bad symbols. + if (!SymbolName) { + consumeError(SymbolName.takeError()); + continue; + } + auto Flags = JITSymbol::flagsFromObjectSymbol(Symbol); + SymbolTable.insert( + std::make_pair(*SymbolName, RuntimeDyld::SymbolInfo(0, Flags))); + } + } + + void updateSymbolTable(const RuntimeDyld &RTDyld) { + for (auto &SymEntry : SymbolTable) + SymEntry.second = RTDyld.getSymbol(SymEntry.first()); + } + + // Contains the information needed prior to finalization: the object files, + // memory manager, resolver, and flags needed for RuntimeDyld. 
+ struct PreFinalizeContents { + PreFinalizeContents(ObjSetT Objects, SymbolResolverPtrT Resolver, + FinalizerFtor Finalizer, bool ProcessAllSections) + : Objects(std::move(Objects)), Resolver(std::move(Resolver)), + Finalizer(std::move(Finalizer)), + ProcessAllSections(ProcessAllSections) {} + + ObjSetT Objects; + SymbolResolverPtrT Resolver; + FinalizerFtor Finalizer; + bool ProcessAllSections; + ObjSetHandleT Handle; + RuntimeDyld *RTDyld; + }; + MemoryManagerPtrT MemMgr; - SymbolResolverPtrT Resolver; + std::unique_ptr<PreFinalizeContents> PFC; }; - template <typename MemoryManagerPtrT, typename SymbolResolverPtrT> - std::unique_ptr<LinkedObjectSet> - createLinkedObjectSet(MemoryManagerPtrT MemMgr, SymbolResolverPtrT Resolver, + template <typename ObjSetT, typename MemoryManagerPtrT, + typename SymbolResolverPtrT, typename FinalizerFtor> + std::unique_ptr< + ConcreteLinkedObjectSet<ObjSetT, MemoryManagerPtrT, + SymbolResolverPtrT, FinalizerFtor>> + createLinkedObjectSet(ObjSetT Objects, MemoryManagerPtrT MemMgr, + SymbolResolverPtrT Resolver, + FinalizerFtor Finalizer, bool ProcessAllSections) { - typedef ConcreteLinkedObjectSet<MemoryManagerPtrT, SymbolResolverPtrT> LOS; - return llvm::make_unique<LOS>(std::move(MemMgr), std::move(Resolver), + typedef ConcreteLinkedObjectSet<ObjSetT, MemoryManagerPtrT, + SymbolResolverPtrT, FinalizerFtor> LOS; + return llvm::make_unique<LOS>(std::move(Objects), std::move(MemMgr), + std::move(Resolver), std::move(Finalizer), ProcessAllSections); } @@ -133,9 +218,6 @@ public: typedef std::vector<std::unique_ptr<RuntimeDyld::LoadedObjectInfo>> LoadedObjInfoList; - /// @brief Functor for receiving finalization notifications. - typedef std::function<void(ObjSetHandleT)> NotifyFinalizedFtor; - /// @brief Construct an ObjectLinkingLayer with the given NotifyLoaded, /// and NotifyFinalized functors. 
ObjectLinkingLayer( @@ -158,33 +240,44 @@ public: /// @brief Add a set of objects (or archives) that will be treated as a unit /// for the purposes of symbol lookup and memory management. /// - /// @return A pair containing (1) A handle that can be used to free the memory - /// allocated for the objects, and (2) a LoadedObjInfoList containing - /// one LoadedObjInfo instance for each object at the corresponding - /// index in the Objects list. - /// - /// This version of this method allows the client to pass in an - /// RTDyldMemoryManager instance that will be used to allocate memory and look - /// up external symbol addresses for the given objects. + /// @return A handle that can be used to refer to the loaded objects (for + /// symbol searching, finalization, freeing memory, etc.). template <typename ObjSetT, typename MemoryManagerPtrT, typename SymbolResolverPtrT> - ObjSetHandleT addObjectSet(const ObjSetT &Objects, + ObjSetHandleT addObjectSet(ObjSetT Objects, MemoryManagerPtrT MemMgr, SymbolResolverPtrT Resolver) { - ObjSetHandleT Handle = - LinkedObjSetList.insert( - LinkedObjSetList.end(), - createLinkedObjectSet(std::move(MemMgr), std::move(Resolver), - ProcessAllSections)); - LinkedObjectSet &LOS = **Handle; - LoadedObjInfoList LoadedObjInfos; + auto Finalizer = [&](ObjSetHandleT H, RuntimeDyld &RTDyld, + const ObjSetT &Objs, + std::function<void()> LOSHandleLoad) { + LoadedObjInfoList LoadedObjInfos; + + for (auto &Obj : Objs) + LoadedObjInfos.push_back(RTDyld.loadObject(this->getObject(*Obj))); + + LOSHandleLoad(); - for (auto &Obj : Objects) - LoadedObjInfos.push_back(LOS.addObject(*Obj)); + this->NotifyLoaded(H, Objs, LoadedObjInfos); - NotifyLoaded(Handle, Objects, LoadedObjInfos); + RTDyld.finalizeWithMemoryManagerLocking(); + + if (this->NotifyFinalized) + this->NotifyFinalized(H); + }; + + auto LOS = + createLinkedObjectSet(std::move(Objects), std::move(MemMgr), + std::move(Resolver), std::move(Finalizer), + ProcessAllSections); + // LOS is an 
owning-ptr. Keep a non-owning one so that we can set the handle + // below. + auto *LOSPtr = LOS.get(); + + ObjSetHandleT Handle = LinkedObjSetList.insert(LinkedObjSetList.end(), + std::move(LOS)); + LOSPtr->setHandle(Handle); return Handle; } @@ -224,33 +317,7 @@ public: /// given object set. JITSymbol findSymbolIn(ObjSetHandleT H, StringRef Name, bool ExportedSymbolsOnly) { - if (auto Sym = (*H)->getSymbol(Name)) { - if (Sym.isExported() || !ExportedSymbolsOnly) { - auto Addr = Sym.getAddress(); - auto Flags = Sym.getFlags(); - if (!(*H)->NeedsFinalization()) { - // If this instance has already been finalized then we can just return - // the address. - return JITSymbol(Addr, Flags); - } else { - // If this instance needs finalization return a functor that will do - // it. The functor still needs to double-check whether finalization is - // required, in case someone else finalizes this set before the - // functor is called. - auto GetAddress = - [this, Addr, H]() { - if ((*H)->NeedsFinalization()) { - (*H)->Finalize(); - if (NotifyFinalized) - NotifyFinalized(H); - } - return Addr; - }; - return JITSymbol(std::move(GetAddress), Flags); - } - } - } - return nullptr; + return (*H)->getSymbol(Name, ExportedSymbolsOnly); } /// @brief Map section addresses for the objects associated with the handle H. @@ -263,12 +330,21 @@ public: /// given handle. /// @param H Handle for object set to emit/finalize. 
void emitAndFinalize(ObjSetHandleT H) { - (*H)->Finalize(); - if (NotifyFinalized) - NotifyFinalized(H); + (*H)->finalize(); } private: + + static const object::ObjectFile& getObject(const object::ObjectFile &Obj) { + return Obj; + } + + template <typename ObjT> + static const object::ObjectFile& + getObject(const object::OwningBinary<ObjT> &Obj) { + return *Obj.getBinary(); + } + LinkedObjectSetListT LinkedObjSetList; NotifyLoadedFtor NotifyLoaded; NotifyFinalizedFtor NotifyFinalized; diff --git a/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h b/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h index f96e83ed5a1a..2ffe71c94356 100644 --- a/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h +++ b/include/llvm/ExecutionEngine/Orc/ObjectTransformLayer.h @@ -42,13 +42,13 @@ public: /// @return A handle for the added objects. template <typename ObjSetT, typename MemoryManagerPtrT, typename SymbolResolverPtrT> - ObjSetHandleT addObjectSet(ObjSetT &Objects, MemoryManagerPtrT MemMgr, + ObjSetHandleT addObjectSet(ObjSetT Objects, MemoryManagerPtrT MemMgr, SymbolResolverPtrT Resolver) { for (auto I = Objects.begin(), E = Objects.end(); I != E; ++I) *I = Transform(std::move(*I)); - return BaseLayer.addObjectSet(Objects, std::move(MemMgr), + return BaseLayer.addObjectSet(std::move(Objects), std::move(MemMgr), std::move(Resolver)); } diff --git a/include/llvm/ExecutionEngine/Orc/OrcABISupport.h b/include/llvm/ExecutionEngine/Orc/OrcABISupport.h new file mode 100644 index 000000000000..4a8d0b0b801c --- /dev/null +++ b/include/llvm/ExecutionEngine/Orc/OrcABISupport.h @@ -0,0 +1,232 @@ +//===-------------- OrcABISupport.h - ABI support code ---------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// ABI specific code for Orc, e.g. 
callback assembly. +// +// ABI classes should be part of the JIT *target* process, not the host +// process (except where you're doing hosted JITing and the two are one and the +// same). +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H +#define LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H + +#include "IndirectionUtils.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/Process.h" + +namespace llvm { +namespace orc { + +/// Generic ORC ABI support. +/// +/// This class can be substituted as the target architecure support class for +/// ORC templates that require one (e.g. IndirectStubsManagers). It does not +/// support lazy JITing however, and any attempt to use that functionality +/// will result in execution of an llvm_unreachable. +class OrcGenericABI { +public: + static const unsigned PointerSize = sizeof(uintptr_t); + static const unsigned TrampolineSize = 1; + static const unsigned ResolverCodeSize = 1; + + typedef TargetAddress (*JITReentryFn)(void *CallbackMgr, void *TrampolineId); + + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr) { + llvm_unreachable("writeResolverCode is not supported by the generic host " + "support class"); + } + + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines) { + llvm_unreachable("writeTrampolines is not supported by the generic host " + "support class"); + } + + class IndirectStubsInfo { + public: + const static unsigned StubSize = 1; + unsigned getNumStubs() const { llvm_unreachable("Not supported"); } + void *getStub(unsigned Idx) const { llvm_unreachable("Not supported"); } + void **getPtr(unsigned Idx) const { llvm_unreachable("Not supported"); } + }; + + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal) { + llvm_unreachable("emitIndirectStubsBlock is not supported by the generic 
" + "host support class"); + } +}; + +/// @brief Provide information about stub blocks generated by the +/// makeIndirectStubsBlock function. +template <unsigned StubSizeVal> class GenericIndirectStubsInfo { +public: + const static unsigned StubSize = StubSizeVal; + + GenericIndirectStubsInfo() : NumStubs(0) {} + GenericIndirectStubsInfo(unsigned NumStubs, sys::OwningMemoryBlock StubsMem) + : NumStubs(NumStubs), StubsMem(std::move(StubsMem)) {} + GenericIndirectStubsInfo(GenericIndirectStubsInfo &&Other) + : NumStubs(Other.NumStubs), StubsMem(std::move(Other.StubsMem)) { + Other.NumStubs = 0; + } + GenericIndirectStubsInfo &operator=(GenericIndirectStubsInfo &&Other) { + NumStubs = Other.NumStubs; + Other.NumStubs = 0; + StubsMem = std::move(Other.StubsMem); + return *this; + } + + /// @brief Number of stubs in this block. + unsigned getNumStubs() const { return NumStubs; } + + /// @brief Get a pointer to the stub at the given index, which must be in + /// the range 0 .. getNumStubs() - 1. + void *getStub(unsigned Idx) const { + return static_cast<char *>(StubsMem.base()) + Idx * StubSize; + } + + /// @brief Get a pointer to the implementation-pointer at the given index, + /// which must be in the range 0 .. getNumStubs() - 1. + void **getPtr(unsigned Idx) const { + char *PtrsBase = static_cast<char *>(StubsMem.base()) + NumStubs * StubSize; + return reinterpret_cast<void **>(PtrsBase) + Idx; + } + +private: + unsigned NumStubs; + sys::OwningMemoryBlock StubsMem; +}; + +class OrcAArch64 { +public: + static const unsigned PointerSize = 8; + static const unsigned TrampolineSize = 12; + static const unsigned ResolverCodeSize = 0x120; + + typedef GenericIndirectStubsInfo<8> IndirectStubsInfo; + + typedef TargetAddress (*JITReentryFn)(void *CallbackMgr, void *TrampolineId); + + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. 
+ static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); + + /// @brief Write the requested number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on AArch64, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +/// @brief X86_64 code that's common to all ABIs. +/// +/// X86_64 supports lazy JITing. +class OrcX86_64_Base { +public: + static const unsigned PointerSize = 8; + static const unsigned TrampolineSize = 8; + + typedef GenericIndirectStubsInfo<8> IndirectStubsInfo; + + /// @brief Write the requested number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. Asking for 4 stubs on x86-64, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +/// @brief X86_64 support for SysV ABI (Linux, MacOSX). +/// +/// X86_64_SysV supports lazy JITing.
+class OrcX86_64_SysV : public OrcX86_64_Base { +public: + static const unsigned ResolverCodeSize = 0x6C; + typedef TargetAddress(*JITReentryFn)(void *CallbackMgr, void *TrampolineId); + + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); +}; + +/// @brief X86_64 support for Win32. +/// +/// X86_64_Win32 supports lazy JITing. +class OrcX86_64_Win32 : public OrcX86_64_Base { +public: + static const unsigned ResolverCodeSize = 0x74; + typedef TargetAddress(*JITReentryFn)(void *CallbackMgr, void *TrampolineId); + + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); +}; + +/// @brief I386 support. +/// +/// I386 supports lazy JITing. +class OrcI386 { +public: + static const unsigned PointerSize = 4; + static const unsigned TrampolineSize = 8; + static const unsigned ResolverCodeSize = 0x4a; + + typedef GenericIndirectStubsInfo<8> IndirectStubsInfo; + + typedef TargetAddress (*JITReentryFn)(void *CallbackMgr, void *TrampolineId); + + /// @brief Write the resolver code into the given memory. The user is be + /// responsible for allocating the memory and setting permissions. + static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, + void *CallbackMgr); + + /// @brief Write the requsted number of trampolines into the given memory, + /// which must be big enough to hold 1 pointer, plus NumTrampolines + /// trampolines. + static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, + unsigned NumTrampolines); + + /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to + /// the nearest page size. + /// + /// E.g. 
Asking for 4 stubs on i386, where stubs are 8-bytes, with 4k + /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 + /// will return a block of 1024 (2-pages worth). + static Error emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, + unsigned MinStubs, void *InitialPtrVal); +}; + +} // End namespace orc. +} // End namespace llvm. + +#endif // LLVM_EXECUTIONENGINE_ORC_ORCABISUPPORT_H diff --git a/include/llvm/ExecutionEngine/Orc/OrcArchitectureSupport.h b/include/llvm/ExecutionEngine/Orc/OrcArchitectureSupport.h deleted file mode 100644 index 1b0488bcf00d..000000000000 --- a/include/llvm/ExecutionEngine/Orc/OrcArchitectureSupport.h +++ /dev/null @@ -1,148 +0,0 @@ -//===-- OrcArchitectureSupport.h - Architecture support code ---*- C++ -*-===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// Architecture specific code for Orc, e.g. callback assembly. -// -// Architecture classes should be part of the JIT *target* process, not the host -// process (except where you're doing hosted JITing and the two are one and the -// same). -// -//===----------------------------------------------------------------------===// - -#ifndef LLVM_EXECUTIONENGINE_ORC_ORCARCHITECTURESUPPORT_H -#define LLVM_EXECUTIONENGINE_ORC_ORCARCHITECTURESUPPORT_H - -#include "IndirectionUtils.h" -#include "llvm/Support/Memory.h" -#include "llvm/Support/Process.h" - -namespace llvm { -namespace orc { - -/// Generic ORC Architecture support. -/// -/// This class can be substituted as the target architecure support class for -/// ORC templates that require one (e.g. IndirectStubsManagers). It does not -/// support lazy JITing however, and any attempt to use that functionality -/// will result in execution of an llvm_unreachable. 
-class OrcGenericArchitecture { -public: - static const unsigned PointerSize = sizeof(uintptr_t); - static const unsigned TrampolineSize = 1; - static const unsigned ResolverCodeSize = 1; - - typedef TargetAddress (*JITReentryFn)(void *CallbackMgr, void *TrampolineId); - - static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, - void *CallbackMgr) { - llvm_unreachable("writeResolverCode is not supported by the generic host " - "support class"); - } - - static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, - unsigned NumTrampolines) { - llvm_unreachable("writeTrampolines is not supported by the generic host " - "support class"); - } - - class IndirectStubsInfo { - public: - const static unsigned StubSize = 1; - unsigned getNumStubs() const { llvm_unreachable("Not supported"); } - void *getStub(unsigned Idx) const { llvm_unreachable("Not supported"); } - void **getPtr(unsigned Idx) const { llvm_unreachable("Not supported"); } - }; - - static std::error_code emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, - unsigned MinStubs, - void *InitialPtrVal) { - llvm_unreachable("emitIndirectStubsBlock is not supported by the generic " - "host support class"); - } -}; - -/// @brief X86_64 support. -/// -/// X86_64 supports lazy JITing. -class OrcX86_64 { -public: - static const unsigned PointerSize = 8; - static const unsigned TrampolineSize = 8; - static const unsigned ResolverCodeSize = 0x78; - - typedef TargetAddress (*JITReentryFn)(void *CallbackMgr, void *TrampolineId); - - /// @brief Write the resolver code into the given memory. The user is be - /// responsible for allocating the memory and setting permissions. - static void writeResolverCode(uint8_t *ResolveMem, JITReentryFn Reentry, - void *CallbackMgr); - - /// @brief Write the requsted number of trampolines into the given memory, - /// which must be big enough to hold 1 pointer, plus NumTrampolines - /// trampolines. 
- static void writeTrampolines(uint8_t *TrampolineMem, void *ResolverAddr, - unsigned NumTrampolines); - - /// @brief Provide information about stub blocks generated by the - /// makeIndirectStubsBlock function. - class IndirectStubsInfo { - friend class OrcX86_64; - - public: - const static unsigned StubSize = 8; - - IndirectStubsInfo() : NumStubs(0) {} - IndirectStubsInfo(IndirectStubsInfo &&Other) - : NumStubs(Other.NumStubs), StubsMem(std::move(Other.StubsMem)) { - Other.NumStubs = 0; - } - IndirectStubsInfo &operator=(IndirectStubsInfo &&Other) { - NumStubs = Other.NumStubs; - Other.NumStubs = 0; - StubsMem = std::move(Other.StubsMem); - return *this; - } - - /// @brief Number of stubs in this block. - unsigned getNumStubs() const { return NumStubs; } - - /// @brief Get a pointer to the stub at the given index, which must be in - /// the range 0 .. getNumStubs() - 1. - void *getStub(unsigned Idx) const { - return static_cast<uint64_t *>(StubsMem.base()) + Idx; - } - - /// @brief Get a pointer to the implementation-pointer at the given index, - /// which must be in the range 0 .. getNumStubs() - 1. - void **getPtr(unsigned Idx) const { - char *PtrsBase = - static_cast<char *>(StubsMem.base()) + NumStubs * StubSize; - return reinterpret_cast<void **>(PtrsBase) + Idx; - } - - private: - unsigned NumStubs; - sys::OwningMemoryBlock StubsMem; - }; - - /// @brief Emit at least MinStubs worth of indirect call stubs, rounded out to - /// the nearest page size. - /// - /// E.g. Asking for 4 stubs on x86-64, where stubs are 8-bytes, with 4k - /// pages will return a block of 512 stubs (4096 / 8 = 512). Asking for 513 - /// will return a block of 1024 (2-pages worth). - static std::error_code emitIndirectStubsBlock(IndirectStubsInfo &StubsInfo, - unsigned MinStubs, - void *InitialPtrVal); -}; - -} // End namespace orc. -} // End namespace llvm. 
- -#endif // LLVM_EXECUTIONENGINE_ORC_ORCARCHITECTURESUPPORT_H diff --git a/include/llvm/ExecutionEngine/Orc/OrcError.h b/include/llvm/ExecutionEngine/Orc/OrcError.h index 48f35d6b39be..1b3f25fae162 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcError.h +++ b/include/llvm/ExecutionEngine/Orc/OrcError.h @@ -14,6 +14,7 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_ORCERROR_H #define LLVM_EXECUTIONENGINE_ORC_ORCERROR_H +#include "llvm/Support/Error.h" #include <system_error> namespace llvm { @@ -26,10 +27,11 @@ enum class OrcErrorCode : int { RemoteMProtectAddrUnrecognized, RemoteIndirectStubsOwnerDoesNotExist, RemoteIndirectStubsOwnerIdAlreadyInUse, - UnexpectedRPCCall + UnexpectedRPCCall, + UnexpectedRPCResponse, }; -std::error_code orcError(OrcErrorCode ErrCode); +Error orcError(OrcErrorCode ErrCode); } // End namespace orc. } // End namespace llvm. diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h index d7640b8e8b5f..5c867e7e7fd4 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetClient.h @@ -36,6 +36,23 @@ namespace remote { template <typename ChannelT> class OrcRemoteTargetClient : public OrcRemoteTargetRPCAPI { public: + // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. 
+ + OrcRemoteTargetClient(const OrcRemoteTargetClient &) = delete; + OrcRemoteTargetClient &operator=(const OrcRemoteTargetClient &) = delete; + + OrcRemoteTargetClient(OrcRemoteTargetClient &&Other) + : Channel(Other.Channel), ExistingError(std::move(Other.ExistingError)), + RemoteTargetTriple(std::move(Other.RemoteTargetTriple)), + RemotePointerSize(std::move(Other.RemotePointerSize)), + RemotePageSize(std::move(Other.RemotePageSize)), + RemoteTrampolineSize(std::move(Other.RemoteTrampolineSize)), + RemoteIndirectStubSize(std::move(Other.RemoteIndirectStubSize)), + AllocatorIds(std::move(Other.AllocatorIds)), + IndirectStubOwnerIds(std::move(Other.IndirectStubOwnerIds)) {} + + OrcRemoteTargetClient &operator=(OrcRemoteTargetClient &&) = delete; + /// Remote memory manager. class RCMemoryManager : public RuntimeDyld::MemoryManager { public: @@ -57,7 +74,7 @@ public: return *this; } - ~RCMemoryManager() { + ~RCMemoryManager() override { Client.destroyRemoteAllocator(Id); DEBUG(dbgs() << "Destroyed remote allocator " << Id << "\n"); } @@ -105,11 +122,13 @@ public: DEBUG(dbgs() << "Allocator " << Id << " reserved:\n"); if (CodeSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteCodeAddr, - Id, CodeSize, CodeAlign); - // FIXME; Add error to poll. - assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, CodeSize, CodeAlign)) + Unmapped.back().RemoteCodeAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.takeError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " code: " << format("0x%016x", Unmapped.back().RemoteCodeAddr) << " (" << CodeSize << " bytes, alignment " << CodeAlign @@ -117,11 +136,13 @@ public: } if (RODataSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRODataAddr, - Id, RODataSize, RODataAlign); - // FIXME; Add error to poll. 
- assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, RODataSize, RODataAlign)) + Unmapped.back().RemoteRODataAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.takeError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " ro-data: " << format("0x%016x", Unmapped.back().RemoteRODataAddr) << " (" << RODataSize << " bytes, alignment " @@ -129,11 +150,13 @@ public: } if (RWDataSize != 0) { - std::error_code EC = Client.reserveMem(Unmapped.back().RemoteRWDataAddr, - Id, RWDataSize, RWDataAlign); - // FIXME; Add error to poll. - assert(!EC && "Failed reserving remote memory."); - (void)EC; + if (auto AddrOrErr = Client.reserveMem(Id, RWDataSize, RWDataAlign)) + Unmapped.back().RemoteRWDataAddr = *AddrOrErr; + else { + // FIXME; Add error to poll. + assert(!AddrOrErr.takeError() && "Failed reserving remote memory."); + } + DEBUG(dbgs() << " rw-data: " << format("0x%016x", Unmapped.back().RemoteRWDataAddr) << " (" << RWDataSize << " bytes, alignment " @@ -144,10 +167,18 @@ public: bool needsToReserveAllocationSpace() override { return true; } void registerEHFrames(uint8_t *Addr, uint64_t LoadAddr, - size_t Size) override {} + size_t Size) override { + UnfinalizedEHFrames.push_back( + std::make_pair(LoadAddr, static_cast<uint32_t>(Size))); + } - void deregisterEHFrames(uint8_t *addr, uint64_t LoadAddr, - size_t Size) override {} + void deregisterEHFrames(uint8_t *Addr, uint64_t LoadAddr, + size_t Size) override { + auto Err = Client.deregisterEHFrames(LoadAddr, Size); + // FIXME: Add error poll. 
+ assert(!Err && "Failed to register remote EH frames."); + (void)Err; + } void notifyObjectLoaded(RuntimeDyld &Dyld, const object::ObjectFile &Obj) override { @@ -156,7 +187,7 @@ public: { TargetAddress NextCodeAddr = ObjAllocs.RemoteCodeAddr; for (auto &Alloc : ObjAllocs.CodeAllocs) { - NextCodeAddr = RoundUpToAlignment(NextCodeAddr, Alloc.getAlign()); + NextCodeAddr = alignTo(NextCodeAddr, Alloc.getAlign()); Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextCodeAddr); DEBUG(dbgs() << " code: " << static_cast<void *>(Alloc.getLocalAddress()) @@ -168,8 +199,7 @@ public: { TargetAddress NextRODataAddr = ObjAllocs.RemoteRODataAddr; for (auto &Alloc : ObjAllocs.RODataAllocs) { - NextRODataAddr = - RoundUpToAlignment(NextRODataAddr, Alloc.getAlign()); + NextRODataAddr = alignTo(NextRODataAddr, Alloc.getAlign()); Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextRODataAddr); DEBUG(dbgs() << " ro-data: " << static_cast<void *>(Alloc.getLocalAddress()) @@ -182,8 +212,7 @@ public: { TargetAddress NextRWDataAddr = ObjAllocs.RemoteRWDataAddr; for (auto &Alloc : ObjAllocs.RWDataAllocs) { - NextRWDataAddr = - RoundUpToAlignment(NextRWDataAddr, Alloc.getAlign()); + NextRWDataAddr = alignTo(NextRWDataAddr, Alloc.getAlign()); Dyld.mapSectionAddress(Alloc.getLocalAddress(), NextRWDataAddr); DEBUG(dbgs() << " rw-data: " << static_cast<void *>(Alloc.getLocalAddress()) @@ -208,15 +237,35 @@ public: << static_cast<void *>(Alloc.getLocalAddress()) << " -> " << format("0x%016x", Alloc.getRemoteAddress()) << " (" << Alloc.getSize() << " bytes)\n"); - Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), - Alloc.getSize()); + if (auto Err = + Client.writeMem(Alloc.getRemoteAddress(), + Alloc.getLocalAddress(), Alloc.getSize())) { + // FIXME: Replace this once finalizeMemory can return an Error. 
+ handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return true; + } } if (ObjAllocs.RemoteCodeAddr) { DEBUG(dbgs() << " setting R-X permissions on code block: " << format("0x%016x", ObjAllocs.RemoteCodeAddr) << "\n"); - Client.setProtections(Id, ObjAllocs.RemoteCodeAddr, - sys::Memory::MF_READ | sys::Memory::MF_EXEC); + if (auto Err = Client.setProtections(Id, ObjAllocs.RemoteCodeAddr, + sys::Memory::MF_READ | + sys::Memory::MF_EXEC)) { + // FIXME: Replace this once finalizeMemory can return an Error. + handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return true; + } } for (auto &Alloc : ObjAllocs.RODataAllocs) { @@ -224,16 +273,35 @@ public: << static_cast<void *>(Alloc.getLocalAddress()) << " -> " << format("0x%016x", Alloc.getRemoteAddress()) << " (" << Alloc.getSize() << " bytes)\n"); - Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), - Alloc.getSize()); + if (auto Err = + Client.writeMem(Alloc.getRemoteAddress(), + Alloc.getLocalAddress(), Alloc.getSize())) { + // FIXME: Replace this once finalizeMemory can return an Error. + handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return true; + } } if (ObjAllocs.RemoteRODataAddr) { DEBUG(dbgs() << " setting R-- permissions on ro-data block: " << format("0x%016x", ObjAllocs.RemoteRODataAddr) << "\n"); - Client.setProtections(Id, ObjAllocs.RemoteRODataAddr, - sys::Memory::MF_READ); + if (auto Err = Client.setProtections(Id, ObjAllocs.RemoteRODataAddr, + sys::Memory::MF_READ)) { + // FIXME: Replace this once finalizeMemory can return an Error. 
+ handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return false; + } } for (auto &Alloc : ObjAllocs.RWDataAllocs) { @@ -241,20 +309,54 @@ public: << static_cast<void *>(Alloc.getLocalAddress()) << " -> " << format("0x%016x", Alloc.getRemoteAddress()) << " (" << Alloc.getSize() << " bytes)\n"); - Client.writeMem(Alloc.getRemoteAddress(), Alloc.getLocalAddress(), - Alloc.getSize()); + if (auto Err = + Client.writeMem(Alloc.getRemoteAddress(), + Alloc.getLocalAddress(), Alloc.getSize())) { + // FIXME: Replace this once finalizeMemory can return an Error. + handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return false; + } } if (ObjAllocs.RemoteRWDataAddr) { DEBUG(dbgs() << " setting RW- permissions on rw-data block: " << format("0x%016x", ObjAllocs.RemoteRWDataAddr) << "\n"); - Client.setProtections(Id, ObjAllocs.RemoteRWDataAddr, - sys::Memory::MF_READ | sys::Memory::MF_WRITE); + if (auto Err = Client.setProtections(Id, ObjAllocs.RemoteRWDataAddr, + sys::Memory::MF_READ | + sys::Memory::MF_WRITE)) { + // FIXME: Replace this once finalizeMemory can return an Error. + handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return false; + } } } Unfinalized.clear(); + for (auto &EHFrame : UnfinalizedEHFrames) { + if (auto Err = Client.registerEHFrames(EHFrame.first, EHFrame.second)) { + // FIXME: Replace this once finalizeMemory can return an Error. 
+ handleAllErrors(std::move(Err), [&](ErrorInfoBase &EIB) { + if (ErrMsg) { + raw_string_ostream ErrOut(*ErrMsg); + EIB.log(ErrOut); + } + }); + return false; + } + } + UnfinalizedEHFrames.clear(); + return false; } @@ -262,8 +364,7 @@ public: class Alloc { public: Alloc(uint64_t Size, unsigned Align) - : Size(Size), Align(Align), Contents(new char[Size + Align - 1]), - RemoteAddr(0) {} + : Size(Size), Align(Align), Contents(new char[Size + Align - 1]) {} Alloc(Alloc &&Other) : Size(std::move(Other.Size)), Align(std::move(Other.Align)), @@ -284,7 +385,7 @@ public: char *getLocalAddress() const { uintptr_t LocalAddr = reinterpret_cast<uintptr_t>(Contents.get()); - LocalAddr = RoundUpToAlignment(LocalAddr, Align); + LocalAddr = alignTo(LocalAddr, Align); return reinterpret_cast<char *>(LocalAddr); } @@ -298,12 +399,11 @@ public: uint64_t Size; unsigned Align; std::unique_ptr<char[]> Contents; - TargetAddress RemoteAddr; + TargetAddress RemoteAddr = 0; }; struct ObjectAllocs { - ObjectAllocs() - : RemoteCodeAddr(0), RemoteRODataAddr(0), RemoteRWDataAddr(0) {} + ObjectAllocs() = default; ObjectAllocs(ObjectAllocs &&Other) : RemoteCodeAddr(std::move(Other.RemoteCodeAddr)), @@ -323,9 +423,9 @@ public: return *this; } - TargetAddress RemoteCodeAddr; - TargetAddress RemoteRODataAddr; - TargetAddress RemoteRWDataAddr; + TargetAddress RemoteCodeAddr = 0; + TargetAddress RemoteRODataAddr = 0; + TargetAddress RemoteRWDataAddr = 0; std::vector<Alloc> CodeAllocs, RODataAllocs, RWDataAllocs; }; @@ -333,6 +433,7 @@ public: ResourceIdMgr::ResourceId Id; std::vector<ObjectAllocs> Unmapped; std::vector<ObjectAllocs> Unfinalized; + std::vector<std::pair<uint64_t, uint32_t>> UnfinalizedEHFrames; }; /// Remote indirect stubs manager. 
@@ -342,26 +443,31 @@ public: ResourceIdMgr::ResourceId Id) : Remote(Remote), Id(Id) {} - ~RCIndirectStubsManager() { Remote.destroyIndirectStubsManager(Id); } + ~RCIndirectStubsManager() override { + if (auto Err = Remote.destroyIndirectStubsManager(Id)) { + // FIXME: Thread this error back to clients. + consumeError(std::move(Err)); + } + } - std::error_code createStub(StringRef StubName, TargetAddress StubAddr, - JITSymbolFlags StubFlags) override { - if (auto EC = reserveStubs(1)) - return EC; + Error createStub(StringRef StubName, TargetAddress StubAddr, + JITSymbolFlags StubFlags) override { + if (auto Err = reserveStubs(1)) + return Err; return createStubInternal(StubName, StubAddr, StubFlags); } - std::error_code createStubs(const StubInitsMap &StubInits) override { - if (auto EC = reserveStubs(StubInits.size())) - return EC; + Error createStubs(const StubInitsMap &StubInits) override { + if (auto Err = reserveStubs(StubInits.size())) + return Err; for (auto &Entry : StubInits) - if (auto EC = createStubInternal(Entry.first(), Entry.second.first, - Entry.second.second)) - return EC; + if (auto Err = createStubInternal(Entry.first(), Entry.second.first, + Entry.second.second)) + return Err; - return std::error_code(); + return Error::success(); } JITSymbol findStub(StringRef Name, bool ExportedStubsOnly) override { @@ -385,8 +491,7 @@ public: return JITSymbol(getPtrAddr(Key), Flags); } - std::error_code updatePointer(StringRef Name, - TargetAddress NewAddr) override { + Error updatePointer(StringRef Name, TargetAddress NewAddr) override { auto I = StubIndexes.find(Name); assert(I != StubIndexes.end() && "No stub pointer for symbol"); auto Key = I->second.first; @@ -395,9 +500,6 @@ public: private: struct RemoteIndirectStubsInfo { - RemoteIndirectStubsInfo(TargetAddress StubBase, TargetAddress PtrBase, - unsigned NumStubs) - : StubBase(StubBase), PtrBase(PtrBase), NumStubs(NumStubs) {} TargetAddress StubBase; TargetAddress PtrBase; unsigned NumStubs; @@ 
-410,31 +512,31 @@ public: std::vector<StubKey> FreeStubs; StringMap<std::pair<StubKey, JITSymbolFlags>> StubIndexes; - std::error_code reserveStubs(unsigned NumStubs) { + Error reserveStubs(unsigned NumStubs) { if (NumStubs <= FreeStubs.size()) - return std::error_code(); + return Error::success(); unsigned NewStubsRequired = NumStubs - FreeStubs.size(); TargetAddress StubBase; TargetAddress PtrBase; unsigned NumStubsEmitted; - Remote.emitIndirectStubs(StubBase, PtrBase, NumStubsEmitted, Id, - NewStubsRequired); + if (auto StubInfoOrErr = Remote.emitIndirectStubs(Id, NewStubsRequired)) + std::tie(StubBase, PtrBase, NumStubsEmitted) = *StubInfoOrErr; + else + return StubInfoOrErr.takeError(); unsigned NewBlockId = RemoteIndirectStubsInfos.size(); - RemoteIndirectStubsInfos.push_back( - RemoteIndirectStubsInfo(StubBase, PtrBase, NumStubsEmitted)); + RemoteIndirectStubsInfos.push_back({StubBase, PtrBase, NumStubsEmitted}); for (unsigned I = 0; I < NumStubsEmitted; ++I) FreeStubs.push_back(std::make_pair(NewBlockId, I)); - return std::error_code(); + return Error::success(); } - std::error_code createStubInternal(StringRef StubName, - TargetAddress InitAddr, - JITSymbolFlags StubFlags) { + Error createStubInternal(StringRef StubName, TargetAddress InitAddr, + JITSymbolFlags StubFlags) { auto Key = FreeStubs.back(); FreeStubs.pop_back(); StubIndexes[StubName] = std::make_pair(Key, StubFlags); @@ -461,20 +563,18 @@ public: public: RCCompileCallbackManager(TargetAddress ErrorHandlerAddress, OrcRemoteTargetClient &Remote) - : JITCompileCallbackManager(ErrorHandlerAddress), Remote(Remote) { - assert(!Remote.CompileCallback && "Compile callback already set"); - Remote.CompileCallback = [this](TargetAddress TrampolineAddr) { - return executeCompileCallback(TrampolineAddr); - }; - Remote.emitResolverBlock(); - } + : JITCompileCallbackManager(ErrorHandlerAddress), Remote(Remote) {} private: - void grow() { + void grow() override { TargetAddress BlockAddr = 0; uint32_t 
NumTrampolines = 0; - auto EC = Remote.emitTrampolineBlock(BlockAddr, NumTrampolines); - assert(!EC && "Failed to create trampolines"); + if (auto TrampolineInfoOrErr = Remote.emitTrampolineBlock()) + std::tie(BlockAddr, NumTrampolines) = *TrampolineInfoOrErr; + else { + // FIXME: Return error. + llvm_unreachable("Failed to create trampolines"); + } uint32_t TrampolineSize = Remote.getTrampolineSize(); for (unsigned I = 0; I < NumTrampolines; ++I) @@ -487,143 +587,123 @@ public: /// Create an OrcRemoteTargetClient. /// Channel is the ChannelT instance to communicate on. It is assumed that /// the channel is ready to be read from and written to. - static ErrorOr<OrcRemoteTargetClient> Create(ChannelT &Channel) { - std::error_code EC; - OrcRemoteTargetClient H(Channel, EC); - if (EC) - return EC; - return H; + static Expected<OrcRemoteTargetClient> Create(ChannelT &Channel) { + Error Err; + OrcRemoteTargetClient H(Channel, Err); + if (Err) + return std::move(Err); + return Expected<OrcRemoteTargetClient>(std::move(H)); } /// Call the int(void) function at the given address in the target and return /// its result. - std::error_code callIntVoid(int &Result, TargetAddress Addr) { + Expected<int> callIntVoid(TargetAddress Addr) { DEBUG(dbgs() << "Calling int(*)(void) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallIntVoid>(Channel, Addr)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallIntVoidResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallIntVoidResponse>(Channel, [&](int R) { - Result = R; - DEBUG(dbgs() << "Result: " << R << "\n"); - return std::error_code(); - }); + auto Listen = [&](RPCChannel &C, uint32_t Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallIntVoid>(Channel, Listen, Addr); } /// Call the int(int, char*[]) function at the given address in the target and /// return its result. 
- std::error_code callMain(int &Result, TargetAddress Addr, - const std::vector<std::string> &Args) { + Expected<int> callMain(TargetAddress Addr, + const std::vector<std::string> &Args) { DEBUG(dbgs() << "Calling int(*)(int, char*[]) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallMain>(Channel, Addr, Args)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallMainResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallMainResponse>(Channel, [&](int R) { - Result = R; - DEBUG(dbgs() << "Result: " << R << "\n"); - return std::error_code(); - }); + auto Listen = [&](RPCChannel &C, uint32_t Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallMain>(Channel, Listen, Addr, Args); } /// Call the void() function at the given address in the target and wait for /// it to finish. - std::error_code callVoidVoid(TargetAddress Addr) { + Error callVoidVoid(TargetAddress Addr) { DEBUG(dbgs() << "Calling void(*)(void) " << format("0x%016x", Addr) << "\n"); - if (auto EC = call<CallVoidVoid>(Channel, Addr)) - return EC; - - unsigned NextProcId; - if (auto EC = listenForCompileRequests(NextProcId)) - return EC; - - if (NextProcId != CallVoidVoidResponseId) - return orcError(OrcErrorCode::UnexpectedRPCCall); - - return handle<CallVoidVoidResponse>(Channel, doNothing); + auto Listen = [&](RPCChannel &C, uint32_t Id) { + return listenForCompileRequests(C, Id); + }; + return callSTHandling<CallVoidVoid>(Channel, Listen, Addr); } /// Create an RCMemoryManager which will allocate its memory on the remote /// target. 
- std::error_code - createRemoteMemoryManager(std::unique_ptr<RCMemoryManager> &MM) { + Error createRemoteMemoryManager(std::unique_ptr<RCMemoryManager> &MM) { assert(!MM && "MemoryManager should be null before creation."); auto Id = AllocatorIds.getNext(); - if (auto EC = call<CreateRemoteAllocator>(Channel, Id)) - return EC; + if (auto Err = callST<CreateRemoteAllocator>(Channel, Id)) + return Err; MM = llvm::make_unique<RCMemoryManager>(*this, Id); - return std::error_code(); + return Error::success(); } /// Create an RCIndirectStubsManager that will allocate stubs on the remote /// target. - std::error_code - createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) { + Error createIndirectStubsManager(std::unique_ptr<RCIndirectStubsManager> &I) { assert(!I && "Indirect stubs manager should be null before creation."); auto Id = IndirectStubOwnerIds.getNext(); - if (auto EC = call<CreateIndirectStubsOwner>(Channel, Id)) - return EC; + if (auto Err = callST<CreateIndirectStubsOwner>(Channel, Id)) + return Err; I = llvm::make_unique<RCIndirectStubsManager>(*this, Id); - return std::error_code(); + return Error::success(); + } + + Expected<RCCompileCallbackManager &> + enableCompileCallbacks(TargetAddress ErrorHandlerAddress) { + // Check for an 'out-of-band' error, e.g. from an MM destructor. + if (ExistingError) + return std::move(ExistingError); + + // Emit the resolver block on the JIT server. + if (auto Err = callST<EmitResolverBlock>(Channel)) + return std::move(Err); + + // Create the callback manager. + CallbackManager.emplace(ErrorHandlerAddress, *this); + RCCompileCallbackManager &Mgr = *CallbackManager; + return Mgr; } /// Search for symbols in the remote process. Note: This should be used by /// symbol resolvers *after* they've searched the local symbol table in the /// JIT stack. 
- std::error_code getSymbolAddress(TargetAddress &Addr, StringRef Name) { + Expected<TargetAddress> getSymbolAddress(StringRef Name) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) - return ExistingError; - - // Request remote symbol address. - if (auto EC = call<GetSymbolAddress>(Channel, Name)) - return EC; - - return expect<GetSymbolAddressResponse>(Channel, [&](TargetAddress &A) { - Addr = A; - DEBUG(dbgs() << "Remote address lookup " << Name << " = " - << format("0x%016x", Addr) << "\n"); - return std::error_code(); - }); + return std::move(ExistingError); + + return callST<GetSymbolAddress>(Channel, Name); } /// Get the triple for the remote target. const std::string &getTargetTriple() const { return RemoteTargetTriple; } - std::error_code terminateSession() { return call<TerminateSession>(Channel); } + Error terminateSession() { return callST<TerminateSession>(Channel); } private: - OrcRemoteTargetClient(ChannelT &Channel, std::error_code &EC) - : Channel(Channel), RemotePointerSize(0), RemotePageSize(0), - RemoteTrampolineSize(0), RemoteIndirectStubSize(0) { - if ((EC = call<GetRemoteInfo>(Channel))) - return; - - EC = expect<GetRemoteInfoResponse>( - Channel, readArgs(RemoteTargetTriple, RemotePointerSize, RemotePageSize, - RemoteTrampolineSize, RemoteIndirectStubSize)); + OrcRemoteTargetClient(ChannelT &Channel, Error &Err) : Channel(Channel) { + ErrorAsOutParameter EAO(Err); + if (auto RIOrErr = callST<GetRemoteInfo>(Channel)) { + std::tie(RemoteTargetTriple, RemotePointerSize, RemotePageSize, + RemoteTrampolineSize, RemoteIndirectStubSize) = *RIOrErr; + Err = Error::success(); + } else { + Err = joinErrors(RIOrErr.takeError(), std::move(ExistingError)); + } + } + + Error deregisterEHFrames(TargetAddress Addr, uint32_t Size) { + return callST<RegisterEHFrames>(Channel, Addr, Size); } void destroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { - if (auto EC = call<DestroyRemoteAllocator>(Channel, Id)) { + if (auto Err 
= callST<DestroyRemoteAllocator>(Channel, Id)) { // FIXME: This will be triggered by a removeModuleSet call: Propagate // error return up through that. llvm_unreachable("Failed to destroy remote allocator."); @@ -631,46 +711,22 @@ private: } } - std::error_code destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { + Error destroyIndirectStubsManager(ResourceIdMgr::ResourceId Id) { IndirectStubOwnerIds.release(Id); - return call<DestroyIndirectStubsOwner>(Channel, Id); + return callST<DestroyIndirectStubsOwner>(Channel, Id); } - std::error_code emitIndirectStubs(TargetAddress &StubBase, - TargetAddress &PtrBase, - uint32_t &NumStubsEmitted, - ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { - if (auto EC = call<EmitIndirectStubs>(Channel, Id, NumStubsRequired)) - return EC; - - return expect<EmitIndirectStubsResponse>( - Channel, readArgs(StubBase, PtrBase, NumStubsEmitted)); + Expected<std::tuple<TargetAddress, TargetAddress, uint32_t>> + emitIndirectStubs(ResourceIdMgr::ResourceId Id, uint32_t NumStubsRequired) { + return callST<EmitIndirectStubs>(Channel, Id, NumStubsRequired); } - std::error_code emitResolverBlock() { + Expected<std::tuple<TargetAddress, uint32_t>> emitTrampolineBlock() { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) - return ExistingError; + return std::move(ExistingError); - return call<EmitResolverBlock>(Channel); - } - - std::error_code emitTrampolineBlock(TargetAddress &BlockAddr, - uint32_t &NumTrampolines) { - // Check for an 'out-of-band' error, e.g. from an MM destructor. 
- if (ExistingError) - return ExistingError; - - if (auto EC = call<EmitTrampolineBlock>(Channel)) - return EC; - - return expect<EmitTrampolineBlockResponse>( - Channel, [&](TargetAddress BAddr, uint32_t NTrampolines) { - BlockAddr = BAddr; - NumTrampolines = NTrampolines; - return std::error_code(); - }); + return callST<EmitTrampolineBlock>(Channel); } uint32_t getIndirectStubSize() const { return RemoteIndirectStubSize; } @@ -679,100 +735,86 @@ private: uint32_t getTrampolineSize() const { return RemoteTrampolineSize; } - std::error_code listenForCompileRequests(uint32_t &NextId) { + Error listenForCompileRequests(RPCChannel &C, uint32_t &Id) { + assert(CallbackManager && + "No calback manager. enableCompileCallbacks must be called first"); + // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) - return ExistingError; - - if (auto EC = getNextProcId(Channel, NextId)) - return EC; - - while (NextId == RequestCompileId) { - TargetAddress TrampolineAddr = 0; - if (auto EC = handle<RequestCompile>(Channel, readArgs(TrampolineAddr))) - return EC; - - TargetAddress ImplAddr = CompileCallback(TrampolineAddr); - if (auto EC = call<RequestCompileResponse>(Channel, ImplAddr)) - return EC; + return std::move(ExistingError); + + // FIXME: CompileCallback could be an anonymous lambda defined at the use + // site below, but that triggers a GCC 4.7 ICE. When we move off + // GCC 4.7, tidy this up. 
+ auto CompileCallback = + [this](TargetAddress Addr) -> Expected<TargetAddress> { + return this->CallbackManager->executeCompileCallback(Addr); + }; - if (auto EC = getNextProcId(Channel, NextId)) - return EC; + if (Id == RequestCompileId) { + if (auto Err = handle<RequestCompile>(C, CompileCallback)) + return Err; + return Error::success(); } - - return std::error_code(); + // else + return orcError(OrcErrorCode::UnexpectedRPCCall); } - std::error_code readMem(char *Dst, TargetAddress Src, uint64_t Size) { + Expected<std::vector<char>> readMem(char *Dst, TargetAddress Src, + uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) - return ExistingError; - - if (auto EC = call<ReadMem>(Channel, Src, Size)) - return EC; + return std::move(ExistingError); - if (auto EC = expect<ReadMemResponse>( - Channel, [&]() { return Channel.readBytes(Dst, Size); })) - return EC; + return callST<ReadMem>(Channel, Src, Size); + } - return std::error_code(); + Error registerEHFrames(TargetAddress &RAddr, uint32_t Size) { + return callST<RegisterEHFrames>(Channel, RAddr, Size); } - std::error_code reserveMem(TargetAddress &RemoteAddr, - ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + Expected<TargetAddress> reserveMem(ResourceIdMgr::ResourceId Id, + uint64_t Size, uint32_t Align) { // Check for an 'out-of-band' error, e.g. from an MM destructor. 
if (ExistingError) - return ExistingError; - - if (std::error_code EC = call<ReserveMem>(Channel, Id, Size, Align)) - return EC; + return std::move(ExistingError); - return expect<ReserveMemResponse>(Channel, readArgs(RemoteAddr)); + return callST<ReserveMem>(Channel, Id, Size, Align); } - std::error_code setProtections(ResourceIdMgr::ResourceId Id, - TargetAddress RemoteSegAddr, - unsigned ProtFlags) { - return call<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); + Error setProtections(ResourceIdMgr::ResourceId Id, + TargetAddress RemoteSegAddr, unsigned ProtFlags) { + return callST<SetProtections>(Channel, Id, RemoteSegAddr, ProtFlags); } - std::error_code writeMem(TargetAddress Addr, const char *Src, uint64_t Size) { + Error writeMem(TargetAddress Addr, const char *Src, uint64_t Size) { // Check for an 'out-of-band' error, e.g. from an MM destructor. if (ExistingError) - return ExistingError; - - // Make the send call. - if (auto EC = call<WriteMem>(Channel, Addr, Size)) - return EC; - - // Follow this up with the section contents. - if (auto EC = Channel.appendBytes(Src, Size)) - return EC; + return std::move(ExistingError); - return Channel.send(); + return callST<WriteMem>(Channel, DirectBufferWriter(Src, Addr, Size)); } - std::error_code writePointer(TargetAddress Addr, TargetAddress PtrVal) { + Error writePointer(TargetAddress Addr, TargetAddress PtrVal) { // Check for an 'out-of-band' error, e.g. from an MM destructor. 
if (ExistingError) - return ExistingError; + return std::move(ExistingError); - return call<WritePtr>(Channel, Addr, PtrVal); + return callST<WritePtr>(Channel, Addr, PtrVal); } - static std::error_code doNothing() { return std::error_code(); } + static Error doNothing() { return Error::success(); } ChannelT &Channel; - std::error_code ExistingError; + Error ExistingError; std::string RemoteTargetTriple; - uint32_t RemotePointerSize; - uint32_t RemotePageSize; - uint32_t RemoteTrampolineSize; - uint32_t RemoteIndirectStubSize; + uint32_t RemotePointerSize = 0; + uint32_t RemotePageSize = 0; + uint32_t RemoteTrampolineSize = 0; + uint32_t RemoteIndirectStubSize = 0; ResourceIdMgr AllocatorIds, IndirectStubOwnerIds; - std::function<TargetAddress(TargetAddress)> CompileCallback; + Optional<RCCompileCallbackManager> CallbackManager; }; } // end namespace remote @@ -781,4 +823,4 @@ private: #undef DEBUG_TYPE -#endif +#endif // LLVM_EXECUTIONENGINE_ORC_ORCREMOTETARGETCLIENT_H diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h index 96dc24251026..74d851522f79 100644 --- a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetRPCAPI.h @@ -24,12 +24,51 @@ namespace llvm { namespace orc { namespace remote { +class DirectBufferWriter { +public: + DirectBufferWriter() = default; + DirectBufferWriter(const char *Src, TargetAddress Dst, uint64_t Size) + : Src(Src), Dst(Dst), Size(Size) {} + + const char *getSrc() const { return Src; } + TargetAddress getDst() const { return Dst; } + uint64_t getSize() const { return Size; } + +private: + const char *Src; + TargetAddress Dst; + uint64_t Size; +}; + +inline Error serialize(RPCChannel &C, const DirectBufferWriter &DBW) { + if (auto EC = serialize(C, DBW.getDst())) + return EC; + if (auto EC = serialize(C, DBW.getSize())) + return EC; + return C.appendBytes(DBW.getSrc(), DBW.getSize()); +} + +inline 
Error deserialize(RPCChannel &C, DirectBufferWriter &DBW) { + TargetAddress Dst; + if (auto EC = deserialize(C, Dst)) + return EC; + uint64_t Size; + if (auto EC = deserialize(C, Size)) + return EC; + char *Addr = reinterpret_cast<char *>(static_cast<uintptr_t>(Dst)); + + DBW = DirectBufferWriter(0, Dst, Size); + + return C.readBytes(Addr, Size); +} + class OrcRemoteTargetRPCAPI : public RPC<RPCChannel> { protected: class ResourceIdMgr { public: typedef uint64_t ResourceId; - ResourceIdMgr() : NextId(0) {} + static const ResourceId InvalidId = ~0U; + ResourceId getNext() { if (!FreeIds.empty()) { ResourceId I = FreeIds.back(); @@ -41,140 +80,122 @@ protected: void release(ResourceId I) { FreeIds.push_back(I); } private: - ResourceId NextId; + ResourceId NextId = 0; std::vector<ResourceId> FreeIds; }; public: - enum JITProcId : uint32_t { - InvalidId = 0, - CallIntVoidId, - CallIntVoidResponseId, + // FIXME: Remove constructors once MSVC supports synthesizing move-ops. + OrcRemoteTargetRPCAPI() = default; + OrcRemoteTargetRPCAPI(const OrcRemoteTargetRPCAPI &) = delete; + OrcRemoteTargetRPCAPI &operator=(const OrcRemoteTargetRPCAPI &) = delete; + + OrcRemoteTargetRPCAPI(OrcRemoteTargetRPCAPI &&) {} + OrcRemoteTargetRPCAPI &operator=(OrcRemoteTargetRPCAPI &&) { return *this; } + + enum JITFuncId : uint32_t { + InvalidId = RPCFunctionIdTraits<JITFuncId>::InvalidId, + CallIntVoidId = RPCFunctionIdTraits<JITFuncId>::FirstValidId, CallMainId, - CallMainResponseId, CallVoidVoidId, - CallVoidVoidResponseId, CreateRemoteAllocatorId, CreateIndirectStubsOwnerId, + DeregisterEHFramesId, DestroyRemoteAllocatorId, DestroyIndirectStubsOwnerId, EmitIndirectStubsId, - EmitIndirectStubsResponseId, EmitResolverBlockId, EmitTrampolineBlockId, - EmitTrampolineBlockResponseId, GetSymbolAddressId, - GetSymbolAddressResponseId, GetRemoteInfoId, - GetRemoteInfoResponseId, ReadMemId, - ReadMemResponseId, + RegisterEHFramesId, ReserveMemId, - ReserveMemResponseId, RequestCompileId, - 
RequestCompileResponseId, SetProtectionsId, TerminateSessionId, WriteMemId, WritePtrId }; - static const char *getJITProcIdName(JITProcId Id); + static const char *getJITFuncIdName(JITFuncId Id); - typedef Procedure<CallIntVoidId, TargetAddress /* FnAddr */> CallIntVoid; + typedef Function<CallIntVoidId, int32_t(TargetAddress Addr)> CallIntVoid; - typedef Procedure<CallIntVoidResponseId, int /* Result */> - CallIntVoidResponse; - - typedef Procedure<CallMainId, TargetAddress /* FnAddr */, - std::vector<std::string> /* Args */> + typedef Function<CallMainId, + int32_t(TargetAddress Addr, std::vector<std::string> Args)> CallMain; - typedef Procedure<CallMainResponseId, int /* Result */> CallMainResponse; - - typedef Procedure<CallVoidVoidId, TargetAddress /* FnAddr */> CallVoidVoid; - - typedef Procedure<CallVoidVoidResponseId> CallVoidVoidResponse; + typedef Function<CallVoidVoidId, void(TargetAddress FnAddr)> CallVoidVoid; - typedef Procedure<CreateRemoteAllocatorId, - ResourceIdMgr::ResourceId /* Allocator ID */> + typedef Function<CreateRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> CreateRemoteAllocator; - typedef Procedure<CreateIndirectStubsOwnerId, - ResourceIdMgr::ResourceId /* StubsOwner ID */> + typedef Function<CreateIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubOwnerID)> CreateIndirectStubsOwner; - typedef Procedure<DestroyRemoteAllocatorId, - ResourceIdMgr::ResourceId /* Allocator ID */> + typedef Function<DeregisterEHFramesId, + void(TargetAddress Addr, uint32_t Size)> + DeregisterEHFrames; + + typedef Function<DestroyRemoteAllocatorId, + void(ResourceIdMgr::ResourceId AllocatorID)> DestroyRemoteAllocator; - typedef Procedure<DestroyIndirectStubsOwnerId, - ResourceIdMgr::ResourceId /* StubsOwner ID */> + typedef Function<DestroyIndirectStubsOwnerId, + void(ResourceIdMgr::ResourceId StubsOwnerID)> DestroyIndirectStubsOwner; - typedef Procedure<EmitIndirectStubsId, - ResourceIdMgr::ResourceId /* StubsOwner ID */, - 
uint32_t /* NumStubsRequired */> + /// EmitIndirectStubs result is (StubsBase, PtrsBase, NumStubsEmitted). + typedef Function<EmitIndirectStubsId, + std::tuple<TargetAddress, TargetAddress, uint32_t>( + ResourceIdMgr::ResourceId StubsOwnerID, + uint32_t NumStubsRequired)> EmitIndirectStubs; - typedef Procedure< - EmitIndirectStubsResponseId, TargetAddress /* StubsBaseAddr */, - TargetAddress /* PtrsBaseAddr */, uint32_t /* NumStubsEmitted */> - EmitIndirectStubsResponse; - - typedef Procedure<EmitResolverBlockId> EmitResolverBlock; + typedef Function<EmitResolverBlockId, void()> EmitResolverBlock; - typedef Procedure<EmitTrampolineBlockId> EmitTrampolineBlock; + /// EmitTrampolineBlock result is (BlockAddr, NumTrampolines). + typedef Function<EmitTrampolineBlockId, std::tuple<TargetAddress, uint32_t>()> + EmitTrampolineBlock; - typedef Procedure<EmitTrampolineBlockResponseId, - TargetAddress /* BlockAddr */, - uint32_t /* NumTrampolines */> - EmitTrampolineBlockResponse; - - typedef Procedure<GetSymbolAddressId, std::string /*SymbolName*/> + typedef Function<GetSymbolAddressId, TargetAddress(std::string SymbolName)> GetSymbolAddress; - typedef Procedure<GetSymbolAddressResponseId, uint64_t /* SymbolAddr */> - GetSymbolAddressResponse; - - typedef Procedure<GetRemoteInfoId> GetRemoteInfo; - - typedef Procedure<GetRemoteInfoResponseId, std::string /* Triple */, - uint32_t /* PointerSize */, uint32_t /* PageSize */, - uint32_t /* TrampolineSize */, - uint32_t /* IndirectStubSize */> - GetRemoteInfoResponse; + /// GetRemoteInfo result is (Triple, PointerSize, PageSize, TrampolineSize, + /// IndirectStubsSize). 
+ typedef Function<GetRemoteInfoId, std::tuple<std::string, uint32_t, uint32_t, + uint32_t, uint32_t>()> + GetRemoteInfo; - typedef Procedure<ReadMemId, TargetAddress /* Src */, uint64_t /* Size */> + typedef Function<ReadMemId, + std::vector<char>(TargetAddress Src, uint64_t Size)> ReadMem; - typedef Procedure<ReadMemResponseId> ReadMemResponse; + typedef Function<RegisterEHFramesId, void(TargetAddress Addr, uint32_t Size)> + RegisterEHFrames; - typedef Procedure<ReserveMemId, ResourceIdMgr::ResourceId /* Id */, - uint64_t /* Size */, uint32_t /* Align */> + typedef Function<ReserveMemId, + TargetAddress(ResourceIdMgr::ResourceId AllocID, + uint64_t Size, uint32_t Align)> ReserveMem; - typedef Procedure<ReserveMemResponseId, TargetAddress /* Addr */> - ReserveMemResponse; - - typedef Procedure<RequestCompileId, TargetAddress /* TrampolineAddr */> + typedef Function<RequestCompileId, + TargetAddress(TargetAddress TrampolineAddr)> RequestCompile; - typedef Procedure<RequestCompileResponseId, TargetAddress /* ImplAddr */> - RequestCompileResponse; - - typedef Procedure<SetProtectionsId, ResourceIdMgr::ResourceId /* Id */, - TargetAddress /* Dst */, uint32_t /* ProtFlags */> + typedef Function<SetProtectionsId, + void(ResourceIdMgr::ResourceId AllocID, TargetAddress Dst, + uint32_t ProtFlags)> SetProtections; - typedef Procedure<TerminateSessionId> TerminateSession; + typedef Function<TerminateSessionId, void()> TerminateSession; - typedef Procedure<WriteMemId, TargetAddress /* Dst */, uint64_t /* Size */ - /* Data should follow */> - WriteMem; + typedef Function<WriteMemId, void(DirectBufferWriter DB)> WriteMem; - typedef Procedure<WritePtrId, TargetAddress /* Dst */, - TargetAddress /* Val */> + typedef Function<WritePtrId, void(TargetAddress Dst, TargetAddress Val)> WritePtr; }; diff --git a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h index 5247661e49ce..bf4299c69b24 100644 --- 
a/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h +++ b/include/llvm/ExecutionEngine/Orc/OrcRemoteTargetServer.h @@ -35,17 +35,31 @@ public: typedef std::function<TargetAddress(const std::string &Name)> SymbolLookupFtor; - OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup) - : Channel(Channel), SymbolLookup(std::move(SymbolLookup)) {} + typedef std::function<void(uint8_t *Addr, uint32_t Size)> + EHFrameRegistrationFtor; - std::error_code getNextProcId(JITProcId &Id) { - return deserialize(Channel, Id); - } + OrcRemoteTargetServer(ChannelT &Channel, SymbolLookupFtor SymbolLookup, + EHFrameRegistrationFtor EHFramesRegister, + EHFrameRegistrationFtor EHFramesDeregister) + : Channel(Channel), SymbolLookup(std::move(SymbolLookup)), + EHFramesRegister(std::move(EHFramesRegister)), + EHFramesDeregister(std::move(EHFramesDeregister)) {} + + // FIXME: Remove move/copy ops once MSVC supports synthesizing move ops. + OrcRemoteTargetServer(const OrcRemoteTargetServer &) = delete; + OrcRemoteTargetServer &operator=(const OrcRemoteTargetServer &) = delete; + + OrcRemoteTargetServer(OrcRemoteTargetServer &&Other) + : Channel(Other.Channel), SymbolLookup(std::move(Other.SymbolLookup)), + EHFramesRegister(std::move(Other.EHFramesRegister)), + EHFramesDeregister(std::move(Other.EHFramesDeregister)) {} + + OrcRemoteTargetServer &operator=(OrcRemoteTargetServer &&) = delete; - std::error_code handleKnownProcedure(JITProcId Id) { + Error handleKnownFunction(JITFuncId Id) { typedef OrcRemoteTargetServer ThisT; - DEBUG(dbgs() << "Handling known proc: " << getJITProcIdName(Id) << "\n"); + DEBUG(dbgs() << "Handling known proc: " << getJITFuncIdName(Id) << "\n"); switch (Id) { case CallIntVoidId: @@ -60,6 +74,9 @@ public: case CreateIndirectStubsOwnerId: return handle<CreateIndirectStubsOwner>( Channel, *this, &ThisT::handleCreateIndirectStubsOwner); + case DeregisterEHFramesId: + return handle<DeregisterEHFrames>(Channel, *this, + 
&ThisT::handleDeregisterEHFrames); case DestroyRemoteAllocatorId: return handle<DestroyRemoteAllocator>( Channel, *this, &ThisT::handleDestroyRemoteAllocator); @@ -82,6 +99,9 @@ public: return handle<GetRemoteInfo>(Channel, *this, &ThisT::handleGetRemoteInfo); case ReadMemId: return handle<ReadMem>(Channel, *this, &ThisT::handleReadMem); + case RegisterEHFramesId: + return handle<RegisterEHFrames>(Channel, *this, + &ThisT::handleRegisterEHFrames); case ReserveMemId: return handle<ReserveMem>(Channel, *this, &ThisT::handleReserveMem); case SetProtectionsId: @@ -98,27 +118,16 @@ public: llvm_unreachable("Unhandled JIT RPC procedure Id."); } - std::error_code requestCompile(TargetAddress &CompiledFnAddr, - TargetAddress TrampolineAddr) { - if (auto EC = call<RequestCompile>(Channel, TrampolineAddr)) - return EC; - - while (1) { - JITProcId Id = InvalidId; - if (auto EC = getNextProcId(Id)) - return EC; - - switch (Id) { - case RequestCompileResponseId: - return handle<RequestCompileResponse>(Channel, - readArgs(CompiledFnAddr)); - default: - if (auto EC = handleKnownProcedure(Id)) - return EC; - } - } + Expected<TargetAddress> requestCompile(TargetAddress TrampolineAddr) { + auto Listen = [&](RPCChannel &C, uint32_t Id) { + return handleKnownFunction(static_cast<JITFuncId>(Id)); + }; + + return callSTHandling<RequestCompile>(Channel, Listen, TrampolineAddr); + } - llvm_unreachable("Fell through request-compile command loop."); + Error handleTerminateSession() { + return handle<TerminateSession>(Channel, []() { return Error::success(); }); } private: @@ -135,60 +144,56 @@ private: sys::Memory::releaseMappedMemory(Alloc.second); } - std::error_code allocate(void *&Addr, size_t Size, uint32_t Align) { + Error allocate(void *&Addr, size_t Size, uint32_t Align) { std::error_code EC; sys::MemoryBlock MB = sys::Memory::allocateMappedMemory( Size, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC); if (EC) - return EC; + return errorCodeToError(EC); Addr = MB.base(); 
assert(Allocs.find(MB.base()) == Allocs.end() && "Duplicate alloc"); Allocs[MB.base()] = std::move(MB); - return std::error_code(); + return Error::success(); } - std::error_code setProtections(void *block, unsigned Flags) { + Error setProtections(void *block, unsigned Flags) { auto I = Allocs.find(block); if (I == Allocs.end()) return orcError(OrcErrorCode::RemoteMProtectAddrUnrecognized); - return sys::Memory::protectMappedMemory(I->second, Flags); + return errorCodeToError( + sys::Memory::protectMappedMemory(I->second, Flags)); } private: std::map<void *, sys::MemoryBlock> Allocs; }; - static std::error_code doNothing() { return std::error_code(); } + static Error doNothing() { return Error::success(); } static TargetAddress reenter(void *JITTargetAddr, void *TrampolineAddr) { - TargetAddress CompiledFnAddr = 0; - auto T = static_cast<OrcRemoteTargetServer *>(JITTargetAddr); - auto EC = T->requestCompile( - CompiledFnAddr, static_cast<TargetAddress>( - reinterpret_cast<uintptr_t>(TrampolineAddr))); - assert(!EC && "Compile request failed"); - (void)EC; - return CompiledFnAddr; + auto AddrOrErr = T->requestCompile(static_cast<TargetAddress>( + reinterpret_cast<uintptr_t>(TrampolineAddr))); + // FIXME: Allow customizable failure substitution functions. 
+ assert(AddrOrErr && "Compile request failed"); + return *AddrOrErr; } - std::error_code handleCallIntVoid(TargetAddress Addr) { + Expected<int32_t> handleCallIntVoid(TargetAddress Addr) { typedef int (*IntVoidFnTy)(); IntVoidFnTy Fn = reinterpret_cast<IntVoidFnTy>(static_cast<uintptr_t>(Addr)); - DEBUG(dbgs() << " Calling " - << reinterpret_cast<void *>(reinterpret_cast<intptr_t>(Fn)) - << "\n"); + DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); int Result = Fn(); DEBUG(dbgs() << " Result = " << Result << "\n"); - return call<CallIntVoidResponse>(Channel, Result); + return Result; } - std::error_code handleCallMain(TargetAddress Addr, - std::vector<std::string> Args) { + Expected<int32_t> handleCallMain(TargetAddress Addr, + std::vector<std::string> Args) { typedef int (*MainFnTy)(int, const char *[]); MainFnTy Fn = reinterpret_cast<MainFnTy>(static_cast<uintptr_t>(Addr)); @@ -199,63 +204,71 @@ private: for (auto &Arg : Args) ArgV[Idx++] = Arg.c_str(); - DEBUG(dbgs() << " Calling " << reinterpret_cast<void *>(Fn) << "\n"); + DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); int Result = Fn(ArgC, ArgV.get()); DEBUG(dbgs() << " Result = " << Result << "\n"); - return call<CallMainResponse>(Channel, Result); + return Result; } - std::error_code handleCallVoidVoid(TargetAddress Addr) { + Error handleCallVoidVoid(TargetAddress Addr) { typedef void (*VoidVoidFnTy)(); VoidVoidFnTy Fn = reinterpret_cast<VoidVoidFnTy>(static_cast<uintptr_t>(Addr)); - DEBUG(dbgs() << " Calling " << reinterpret_cast<void *>(Fn) << "\n"); + DEBUG(dbgs() << " Calling " << format("0x%016x", Addr) << "\n"); Fn(); DEBUG(dbgs() << " Complete.\n"); - return call<CallVoidVoidResponse>(Channel); + return Error::success(); } - std::error_code handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { + Error handleCreateRemoteAllocator(ResourceIdMgr::ResourceId Id) { auto I = Allocators.find(Id); if (I != Allocators.end()) return 
orcError(OrcErrorCode::RemoteAllocatorIdAlreadyInUse); DEBUG(dbgs() << " Created allocator " << Id << "\n"); Allocators[Id] = Allocator(); - return std::error_code(); + return Error::success(); } - std::error_code handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { + Error handleCreateIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { auto I = IndirectStubsOwners.find(Id); if (I != IndirectStubsOwners.end()) return orcError(OrcErrorCode::RemoteIndirectStubsOwnerIdAlreadyInUse); DEBUG(dbgs() << " Create indirect stubs owner " << Id << "\n"); IndirectStubsOwners[Id] = ISBlockOwnerList(); - return std::error_code(); + return Error::success(); + } + + Error handleDeregisterEHFrames(TargetAddress TAddr, uint32_t Size) { + uint8_t *Addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(TAddr)); + DEBUG(dbgs() << " Registering EH frames at " << format("0x%016x", TAddr) + << ", Size = " << Size << " bytes\n"); + EHFramesDeregister(Addr, Size); + return Error::success(); } - std::error_code handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { + Error handleDestroyRemoteAllocator(ResourceIdMgr::ResourceId Id) { auto I = Allocators.find(Id); if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); Allocators.erase(I); DEBUG(dbgs() << " Destroyed allocator " << Id << "\n"); - return std::error_code(); + return Error::success(); } - std::error_code - handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { + Error handleDestroyIndirectStubsOwner(ResourceIdMgr::ResourceId Id) { auto I = IndirectStubsOwners.find(Id); if (I == IndirectStubsOwners.end()) return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); IndirectStubsOwners.erase(I); - return std::error_code(); + return Error::success(); } - std::error_code handleEmitIndirectStubs(ResourceIdMgr::ResourceId Id, - uint32_t NumStubsRequired) { + Expected<std::tuple<TargetAddress, TargetAddress, uint32_t>> + handleEmitIndirectStubs(ResourceIdMgr::ResourceId 
Id, + uint32_t NumStubsRequired) { DEBUG(dbgs() << " ISMgr " << Id << " request " << NumStubsRequired << " stubs.\n"); @@ -264,9 +277,9 @@ private: return orcError(OrcErrorCode::RemoteIndirectStubsOwnerDoesNotExist); typename TargetT::IndirectStubsInfo IS; - if (auto EC = + if (auto Err = TargetT::emitIndirectStubsBlock(IS, NumStubsRequired, nullptr)) - return EC; + return std::move(Err); TargetAddress StubsBase = static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(IS.getStub(0))); @@ -277,36 +290,35 @@ private: auto &BlockList = StubOwnerItr->second; BlockList.push_back(std::move(IS)); - return call<EmitIndirectStubsResponse>(Channel, StubsBase, PtrsBase, - NumStubsEmitted); + return std::make_tuple(StubsBase, PtrsBase, NumStubsEmitted); } - std::error_code handleEmitResolverBlock() { + Error handleEmitResolverBlock() { std::error_code EC; ResolverBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( TargetT::ResolverCodeSize, nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); if (EC) - return EC; + return errorCodeToError(EC); TargetT::writeResolverCode(static_cast<uint8_t *>(ResolverBlock.base()), &reenter, this); - return sys::Memory::protectMappedMemory(ResolverBlock.getMemoryBlock(), - sys::Memory::MF_READ | - sys::Memory::MF_EXEC); + return errorCodeToError(sys::Memory::protectMappedMemory( + ResolverBlock.getMemoryBlock(), + sys::Memory::MF_READ | sys::Memory::MF_EXEC)); } - std::error_code handleEmitTrampolineBlock() { + Expected<std::tuple<TargetAddress, uint32_t>> handleEmitTrampolineBlock() { std::error_code EC; auto TrampolineBlock = sys::OwningMemoryBlock(sys::Memory::allocateMappedMemory( sys::Process::getPageSize(), nullptr, sys::Memory::MF_READ | sys::Memory::MF_WRITE, EC)); if (EC) - return EC; + return errorCodeToError(EC); - unsigned NumTrampolines = + uint32_t NumTrampolines = (sys::Process::getPageSize() - TargetT::PointerSize) / TargetT::TrampolineSize; @@ -320,20 +332,21 @@ private: 
TrampolineBlocks.push_back(std::move(TrampolineBlock)); - return call<EmitTrampolineBlockResponse>( - Channel, - static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)), - NumTrampolines); + auto TrampolineBaseAddr = + static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(TrampolineMem)); + + return std::make_tuple(TrampolineBaseAddr, NumTrampolines); } - std::error_code handleGetSymbolAddress(const std::string &Name) { + Expected<TargetAddress> handleGetSymbolAddress(const std::string &Name) { TargetAddress Addr = SymbolLookup(Name); DEBUG(dbgs() << " Symbol '" << Name << "' = " << format("0x%016x", Addr) << "\n"); - return call<GetSymbolAddressResponse>(Channel, Addr); + return Addr; } - std::error_code handleGetRemoteInfo() { + Expected<std::tuple<std::string, uint32_t, uint32_t, uint32_t, uint32_t>> + handleGetRemoteInfo() { std::string ProcessTriple = sys::getProcessTriple(); uint32_t PointerSize = TargetT::PointerSize; uint32_t PageSize = sys::Process::getPageSize(); @@ -345,35 +358,41 @@ private: << " page size = " << PageSize << "\n" << " trampoline size = " << TrampolineSize << "\n" << " indirect stub size = " << IndirectStubSize << "\n"); - return call<GetRemoteInfoResponse>(Channel, ProcessTriple, PointerSize, - PageSize, TrampolineSize, - IndirectStubSize); + return std::make_tuple(ProcessTriple, PointerSize, PageSize, TrampolineSize, + IndirectStubSize); } - std::error_code handleReadMem(TargetAddress RSrc, uint64_t Size) { + Expected<std::vector<char>> handleReadMem(TargetAddress RSrc, uint64_t Size) { char *Src = reinterpret_cast<char *>(static_cast<uintptr_t>(RSrc)); DEBUG(dbgs() << " Reading " << Size << " bytes from " - << static_cast<void *>(Src) << "\n"); + << format("0x%016x", RSrc) << "\n"); - if (auto EC = call<ReadMemResponse>(Channel)) - return EC; + std::vector<char> Buffer; + Buffer.resize(Size); + for (char *P = Src; Size != 0; --Size) + Buffer.push_back(*P++); - if (auto EC = Channel.appendBytes(Src, Size)) - return EC; 
+ return Buffer; + } - return Channel.send(); + Error handleRegisterEHFrames(TargetAddress TAddr, uint32_t Size) { + uint8_t *Addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(TAddr)); + DEBUG(dbgs() << " Registering EH frames at " << format("0x%016x", TAddr) + << ", Size = " << Size << " bytes\n"); + EHFramesRegister(Addr, Size); + return Error::success(); } - std::error_code handleReserveMem(ResourceIdMgr::ResourceId Id, uint64_t Size, - uint32_t Align) { + Expected<TargetAddress> handleReserveMem(ResourceIdMgr::ResourceId Id, + uint64_t Size, uint32_t Align) { auto I = Allocators.find(Id); if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); auto &Allocator = I->second; void *LocalAllocAddr = nullptr; - if (auto EC = Allocator.allocate(LocalAllocAddr, Size, Align)) - return EC; + if (auto Err = Allocator.allocate(LocalAllocAddr, Size, Align)) + return std::move(Err); DEBUG(dbgs() << " Allocator " << Id << " reserved " << LocalAllocAddr << " (" << Size << " bytes, alignment " << Align << ")\n"); @@ -381,11 +400,11 @@ private: TargetAddress AllocAddr = static_cast<TargetAddress>(reinterpret_cast<uintptr_t>(LocalAllocAddr)); - return call<ReserveMemResponse>(Channel, AllocAddr); + return AllocAddr; } - std::error_code handleSetProtections(ResourceIdMgr::ResourceId Id, - TargetAddress Addr, uint32_t Flags) { + Error handleSetProtections(ResourceIdMgr::ResourceId Id, TargetAddress Addr, + uint32_t Flags) { auto I = Allocators.find(Id); if (I == Allocators.end()) return orcError(OrcErrorCode::RemoteAllocatorDoesNotExist); @@ -398,24 +417,24 @@ private: return Allocator.setProtections(LocalAddr, Flags); } - std::error_code handleWriteMem(TargetAddress RDst, uint64_t Size) { - char *Dst = reinterpret_cast<char *>(static_cast<uintptr_t>(RDst)); - DEBUG(dbgs() << " Writing " << Size << " bytes to " - << format("0x%016x", RDst) << "\n"); - return Channel.readBytes(Dst, Size); + Error handleWriteMem(DirectBufferWriter DBW) { + 
DEBUG(dbgs() << " Writing " << DBW.getSize() << " bytes to " + << format("0x%016x", DBW.getDst()) << "\n"); + return Error::success(); } - std::error_code handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) { + Error handleWritePtr(TargetAddress Addr, TargetAddress PtrVal) { DEBUG(dbgs() << " Writing pointer *" << format("0x%016x", Addr) << " = " << format("0x%016x", PtrVal) << "\n"); uintptr_t *Ptr = reinterpret_cast<uintptr_t *>(static_cast<uintptr_t>(Addr)); *Ptr = static_cast<uintptr_t>(PtrVal); - return std::error_code(); + return Error::success(); } ChannelT &Channel; SymbolLookupFtor SymbolLookup; + EHFrameRegistrationFtor EHFramesRegister, EHFramesDeregister; std::map<ResourceIdMgr::ResourceId, Allocator> Allocators; typedef std::vector<typename TargetT::IndirectStubsInfo> ISBlockOwnerList; std::map<ResourceIdMgr::ResourceId, ISBlockOwnerList> IndirectStubsOwners; diff --git a/include/llvm/ExecutionEngine/Orc/RPCChannel.h b/include/llvm/ExecutionEngine/Orc/RPCChannel.h index b97b6daf5864..c569e3cf05b4 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCChannel.h +++ b/include/llvm/ExecutionEngine/Orc/RPCChannel.h @@ -1,13 +1,27 @@ -// -*- c++ -*- +//===- llvm/ExecutionEngine/Orc/RPCChannel.h --------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// #ifndef LLVM_EXECUTIONENGINE_ORC_RPCCHANNEL_H #define LLVM_EXECUTIONENGINE_ORC_RPCCHANNEL_H #include "OrcError.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Endian.h" - -#include <system_error> +#include "llvm/Support/Error.h" +#include <cstddef> +#include <cstdint> +#include <mutex> +#include <string> +#include <tuple> +#include <vector> namespace llvm { namespace orc { @@ -19,40 +33,73 @@ public: virtual ~RPCChannel() {} /// Read Size bytes from the stream into *Dst. - virtual std::error_code readBytes(char *Dst, unsigned Size) = 0; + virtual Error readBytes(char *Dst, unsigned Size) = 0; /// Read size bytes from *Src and append them to the stream. - virtual std::error_code appendBytes(const char *Src, unsigned Size) = 0; + virtual Error appendBytes(const char *Src, unsigned Size) = 0; /// Flush the stream if possible. - virtual std::error_code send() = 0; + virtual Error send() = 0; + + /// Get the lock for stream reading. + std::mutex &getReadLock() { return readLock; } + + /// Get the lock for stream writing. + std::mutex &getWriteLock() { return writeLock; } + +private: + std::mutex readLock, writeLock; }; +/// Notify the channel that we're starting a message send. +/// Locks the channel for writing. +inline Error startSendMessage(RPCChannel &C) { + C.getWriteLock().lock(); + return Error::success(); +} + +/// Notify the channel that we're ending a message send. +/// Unlocks the channel for writing. +inline Error endSendMessage(RPCChannel &C) { + C.getWriteLock().unlock(); + return Error::success(); +} + +/// Notify the channel that we're starting a message receive. +/// Locks the channel for reading. +inline Error startReceiveMessage(RPCChannel &C) { + C.getReadLock().lock(); + return Error::success(); +} + +/// Notify the channel that we're ending a message receive. 
+/// Unlocks the channel for reading. +inline Error endReceiveMessage(RPCChannel &C) { + C.getReadLock().unlock(); + return Error::success(); +} + /// RPC channel serialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code serialize_seq(RPCChannel &C, const T &Arg, const Ts &... Args) { - if (auto EC = serialize(C, Arg)) - return EC; - return serialize_seq(C, Args...); +Error serializeSeq(RPCChannel &C, const T &Arg, const Ts &... Args) { + if (auto Err = serialize(C, Arg)) + return Err; + return serializeSeq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code serialize_seq(RPCChannel &C) { - return std::error_code(); -} +inline Error serializeSeq(RPCChannel &C) { return Error::success(); } /// RPC channel deserialization for a variadic list of arguments. template <typename T, typename... Ts> -std::error_code deserialize_seq(RPCChannel &C, T &Arg, Ts &... Args) { - if (auto EC = deserialize(C, Arg)) - return EC; - return deserialize_seq(C, Args...); +Error deserializeSeq(RPCChannel &C, T &Arg, Ts &... Args) { + if (auto Err = deserialize(C, Arg)) + return Err; + return deserializeSeq(C, Args...); } /// RPC channel serialization for an (empty) variadic list of arguments. -inline std::error_code deserialize_seq(RPCChannel &C) { - return std::error_code(); -} +inline Error deserializeSeq(RPCChannel &C) { return Error::success(); } /// RPC channel serialization for integer primitives. 
template <typename T> @@ -61,7 +108,7 @@ typename std::enable_if< std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, - std::error_code>::type + Error>::type serialize(RPCChannel &C, T V) { support::endian::byte_swap<T, support::big>(V); return C.appendBytes(reinterpret_cast<const char *>(&V), sizeof(T)); @@ -74,106 +121,129 @@ typename std::enable_if< std::is_same<T, uint32_t>::value || std::is_same<T, int32_t>::value || std::is_same<T, uint16_t>::value || std::is_same<T, int16_t>::value || std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, - std::error_code>::type + Error>::type deserialize(RPCChannel &C, T &V) { - if (auto EC = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) - return EC; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&V), sizeof(T))) + return Err; support::endian::byte_swap<T, support::big>(V); - return std::error_code(); + return Error::success(); } /// RPC channel serialization for enums. template <typename T> -typename std::enable_if<std::is_enum<T>::value, std::error_code>::type +typename std::enable_if<std::is_enum<T>::value, Error>::type serialize(RPCChannel &C, T V) { return serialize(C, static_cast<typename std::underlying_type<T>::type>(V)); } /// RPC channel deserialization for enums. template <typename T> -typename std::enable_if<std::is_enum<T>::value, std::error_code>::type +typename std::enable_if<std::is_enum<T>::value, Error>::type deserialize(RPCChannel &C, T &V) { typename std::underlying_type<T>::type Tmp; - std::error_code EC = deserialize(C, Tmp); + Error Err = deserialize(C, Tmp); V = static_cast<T>(Tmp); - return EC; + return Err; } /// RPC channel serialization for bools. -inline std::error_code serialize(RPCChannel &C, bool V) { +inline Error serialize(RPCChannel &C, bool V) { uint8_t VN = V ? 
1 : 0; return C.appendBytes(reinterpret_cast<const char *>(&VN), 1); } /// RPC channel deserialization for bools. -inline std::error_code deserialize(RPCChannel &C, bool &V) { +inline Error deserialize(RPCChannel &C, bool &V) { uint8_t VN = 0; - if (auto EC = C.readBytes(reinterpret_cast<char *>(&VN), 1)) - return EC; + if (auto Err = C.readBytes(reinterpret_cast<char *>(&VN), 1)) + return Err; - V = (VN != 0) ? true : false; - return std::error_code(); + V = (VN != 0); + return Error::success(); } /// RPC channel serialization for StringRefs. /// Note: There is no corresponding deseralization for this, as StringRef /// doesn't own its memory and so can't hold the deserialized data. -inline std::error_code serialize(RPCChannel &C, StringRef S) { - if (auto EC = serialize(C, static_cast<uint64_t>(S.size()))) - return EC; +inline Error serialize(RPCChannel &C, StringRef S) { + if (auto Err = serialize(C, static_cast<uint64_t>(S.size()))) + return Err; return C.appendBytes((const char *)S.bytes_begin(), S.size()); } /// RPC channel serialization for std::strings. -inline std::error_code serialize(RPCChannel &C, const std::string &S) { +inline Error serialize(RPCChannel &C, const std::string &S) { return serialize(C, StringRef(S)); } /// RPC channel deserialization for std::strings. -inline std::error_code deserialize(RPCChannel &C, std::string &S) { +inline Error deserialize(RPCChannel &C, std::string &S) { uint64_t Count; - if (auto EC = deserialize(C, Count)) - return EC; + if (auto Err = deserialize(C, Count)) + return Err; S.resize(Count); return C.readBytes(&S[0], Count); } +// Serialization helper for std::tuple. +template <typename TupleT, size_t... Is> +inline Error serializeTupleHelper(RPCChannel &C, const TupleT &V, + llvm::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); +} + +/// RPC channel serialization for std::tuple. +template <typename... 
ArgTs> +inline Error serialize(RPCChannel &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); +} + +// Serialization helper for std::tuple. +template <typename TupleT, size_t... Is> +inline Error deserializeTupleHelper(RPCChannel &C, TupleT &V, + llvm::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); +} + +/// RPC channel deserialization for std::tuple. +template <typename... ArgTs> +inline Error deserialize(RPCChannel &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, llvm::index_sequence_for<ArgTs...>()); +} + /// RPC channel serialization for ArrayRef<T>. -template <typename T> -std::error_code serialize(RPCChannel &C, const ArrayRef<T> &A) { - if (auto EC = serialize(C, static_cast<uint64_t>(A.size()))) - return EC; +template <typename T> Error serialize(RPCChannel &C, const ArrayRef<T> &A) { + if (auto Err = serialize(C, static_cast<uint64_t>(A.size()))) + return Err; for (const auto &E : A) - if (auto EC = serialize(C, E)) - return EC; + if (auto Err = serialize(C, E)) + return Err; - return std::error_code(); + return Error::success(); } /// RPC channel serialization for std::array<T>. -template <typename T> -std::error_code serialize(RPCChannel &C, const std::vector<T> &V) { +template <typename T> Error serialize(RPCChannel &C, const std::vector<T> &V) { return serialize(C, ArrayRef<T>(V)); } /// RPC channel deserialization for std::array<T>. 
-template <typename T> -std::error_code deserialize(RPCChannel &C, std::vector<T> &V) { +template <typename T> Error deserialize(RPCChannel &C, std::vector<T> &V) { uint64_t Count = 0; - if (auto EC = deserialize(C, Count)) - return EC; + if (auto Err = deserialize(C, Count)) + return Err; V.resize(Count); for (auto &E : V) - if (auto EC = deserialize(C, E)) - return EC; + if (auto Err = deserialize(C, E)) + return Err; - return std::error_code(); + return Error::success(); } } // end namespace remote } // end namespace orc } // end namespace llvm -#endif +#endif // LLVM_EXECUTIONENGINE_ORC_RPCCHANNEL_H diff --git a/include/llvm/ExecutionEngine/Orc/RPCUtils.h b/include/llvm/ExecutionEngine/Orc/RPCUtils.h index 0bd5cbc0cdde..966a49684348 100644 --- a/include/llvm/ExecutionEngine/Orc/RPCUtils.h +++ b/include/llvm/ExecutionEngine/Orc/RPCUtils.h @@ -14,78 +14,256 @@ #ifndef LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H #define LLVM_EXECUTIONENGINE_ORC_RPCUTILS_H +#include <map> +#include <vector> + +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ExecutionEngine/Orc/OrcError.h" + +#ifdef _MSC_VER +// concrt.h depends on eh.h for __uncaught_exception declaration +// even if we disable exceptions. +#include <eh.h> + +// Disable warnings from ppltasks.h transitively included by <future>. +#pragma warning(push) +#pragma warning(disable : 4530) +#pragma warning(disable : 4062) +#endif + +#include <future> + +#ifdef _MSC_VER +#pragma warning(pop) +#endif namespace llvm { namespace orc { namespace remote { +/// Describes reserved RPC Function Ids. +/// +/// The default implementation will serve for integer and enum function id +/// types. If you want to use a custom type as your FunctionId you can +/// specialize this class and provide unique values for InvalidId, +/// ResponseId and FirstValidId. 
+ +template <typename T> class RPCFunctionIdTraits { +public: + static const T InvalidId = static_cast<T>(0); + static const T ResponseId = static_cast<T>(1); + static const T FirstValidId = static_cast<T>(2); +}; + // Base class containing utilities that require partial specialization. // These cannot be included in RPC, as template class members cannot be // partially specialized. class RPCBase { protected: - template <typename ProcedureIdT, ProcedureIdT ProcId, typename... Ts> - class ProcedureHelper { + // RPC Function description type. + // + // This class provides the information and operations needed to support the + // RPC primitive operations (call, expect, etc) for a given function. It + // is specialized for void and non-void functions to deal with the differences + // betwen the two. Both specializations have the same interface: + // + // Id - The function's unique identifier. + // OptionalReturn - The return type for asyncronous calls. + // ErrorReturn - The return type for synchronous calls. + // optionalToErrorReturn - Conversion from a valid OptionalReturn to an + // ErrorReturn. + // readResult - Deserialize a result from a channel. + // abandon - Abandon a promised (asynchronous) result. + // respond - Retun a result on the channel. + template <typename FunctionIdT, FunctionIdT FuncId, typename FnT> + class FunctionHelper {}; + + // RPC Function description specialization for non-void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename RetT, + typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> { public: - static const ProcedureIdT Id = ProcId; + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. 
" + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); + + static const FunctionIdT Id = FuncId; + + typedef Optional<RetT> OptionalReturn; + + typedef Expected<RetT> ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return std::move(*V); + } + + template <typename ChannelT> + static Error readResult(ChannelT &C, std::promise<OptionalReturn> &P) { + RetT Val; + auto Err = deserialize(C, Val); + auto Err2 = endReceiveMessage(C); + Err = joinErrors(std::move(Err), std::move(Err2)); + + if (Err) { + P.set_value(OptionalReturn()); + return Err; + } + P.set_value(std::move(Val)); + return Error::success(); + } + + static void abandon(std::promise<OptionalReturn> &P) { + P.set_value(OptionalReturn()); + } + + template <typename ChannelT, typename SequenceNumberT> + static Error respond(ChannelT &C, SequenceNumberT SeqNo, + ErrorReturn &Result) { + FunctionIdT ResponseId = RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (!Result) + return Result.takeError(); + + // Otherwise open a new message on the channel and send the result. + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, ResponseId, SeqNo, *Result)) + return Err; + return endSendMessage(C); + } }; - template <typename ChannelT, typename Proc> class CallHelper; + // RPC Function description specialization for void functions. + template <typename FunctionIdT, FunctionIdT FuncId, typename... ArgTs> + class FunctionHelper<FunctionIdT, FuncId, void(ArgTs...)> { + public: + static_assert(FuncId != RPCFunctionIdTraits<FunctionIdT>::InvalidId && + FuncId != RPCFunctionIdTraits<FunctionIdT>::ResponseId, + "Cannot define custom function with InvalidId or ResponseId. " + "Please use RPCFunctionTraits<FunctionIdT>::FirstValidId."); - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... 
ArgTs> - class CallHelper<ChannelT, ProcedureHelper<ProcedureIdT, ProcId, ArgTs...>> { + static const FunctionIdT Id = FuncId; + + typedef bool OptionalReturn; + typedef Error ErrorReturn; + + static ErrorReturn optionalToErrorReturn(OptionalReturn &&V) { + assert(V && "Return value not available"); + return Error::success(); + } + + template <typename ChannelT> + static Error readResult(ChannelT &C, std::promise<OptionalReturn> &P) { + // Void functions don't have anything to deserialize, so we're good. + P.set_value(true); + return endReceiveMessage(C); + } + + static void abandon(std::promise<OptionalReturn> &P) { P.set_value(false); } + + template <typename ChannelT, typename SequenceNumberT> + static Error respond(ChannelT &C, SequenceNumberT SeqNo, + ErrorReturn &Result) { + const FunctionIdT ResponseId = + RPCFunctionIdTraits<FunctionIdT>::ResponseId; + + // If the handler returned an error then bail out with that. + if (Result) + return std::move(Result); + + // Otherwise open a new message on the channel and send the result. + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, ResponseId, SeqNo)) + return Err; + return endSendMessage(C); + } + }; + + // Helper for the call primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class CallHelper; + + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class CallHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = serialize(C, ProcId)) - return EC; - // If you see a compile-error on this line you're probably calling a - // function with the wrong signature. - return serialize_seq(C, Args...); + static Error call(ChannelT &C, SequenceNumberT SeqNo, + const ArgTs &... 
Args) { + if (auto Err = startSendMessage(C)) + return Err; + if (auto Err = serializeSeq(C, FuncId, SeqNo, Args...)) + return Err; + return endSendMessage(C); } }; - template <typename ChannelT, typename Proc> class HandlerHelper; + // Helper for handle primitive. + template <typename ChannelT, typename SequenceNumberT, typename Func> + class HandlerHelper; - template <typename ChannelT, typename ProcedureIdT, ProcedureIdT ProcId, - typename... ArgTs> - class HandlerHelper<ChannelT, - ProcedureHelper<ProcedureIdT, ProcId, ArgTs...>> { + template <typename ChannelT, typename SequenceNumberT, typename FunctionIdT, + FunctionIdT FuncId, typename RetT, typename... ArgTs> + class HandlerHelper<ChannelT, SequenceNumberT, + FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)>> { public: template <typename HandlerT> - static std::error_code handle(ChannelT &C, HandlerT Handler) { + static Error handle(ChannelT &C, HandlerT Handler) { return readAndHandle(C, Handler, llvm::index_sequence_for<ArgTs...>()); } private: + typedef FunctionHelper<FunctionIdT, FuncId, RetT(ArgTs...)> Func; + template <typename HandlerT, size_t... Is> - static std::error_code readAndHandle(ChannelT &C, HandlerT Handler, - llvm::index_sequence<Is...> _) { + static Error readAndHandle(ChannelT &C, HandlerT Handler, + llvm::index_sequence<Is...> _) { std::tuple<ArgTs...> RPCArgs; - if (auto EC = deserialize_seq(C, std::get<Is>(RPCArgs)...)) - return EC; - return Handler(std::get<Is>(RPCArgs)...); + SequenceNumberT SeqNo; + // GCC 4.7 and 4.8 incorrectly issue a -Wunused-but-set-variable warning + // for RPCArgs. Void cast RPCArgs to work around this for now. + // FIXME: Remove this workaround once we can assume a working GCC version. + (void)RPCArgs; + if (auto Err = deserializeSeq(C, SeqNo, std::get<Is>(RPCArgs)...)) + return Err; + + // We've deserialized the arguments, so unlock the channel for reading + // before we call the handler. This allows recursive RPC calls. 
+ if (auto Err = endReceiveMessage(C)) + return Err; + + // Run the handler and get the result. + auto Result = Handler(std::get<Is>(RPCArgs)...); + + // Return the result to the client. + return Func::template respond<ChannelT, SequenceNumberT>(C, SeqNo, + Result); } }; - template <typename ClassT, typename... ArgTs> class MemberFnWrapper { + // Helper for wrapping member functions up as functors. + template <typename ClassT, typename RetT, typename... ArgTs> + class MemberFnWrapper { public: - typedef std::error_code (ClassT::*MethodT)(ArgTs...); + typedef RetT (ClassT::*MethodT)(ArgTs...); MemberFnWrapper(ClassT &Instance, MethodT Method) : Instance(Instance), Method(Method) {} - std::error_code operator()(ArgTs &... Args) { - return (Instance.*Method)(Args...); - } + RetT operator()(ArgTs &... Args) { return (Instance.*Method)(Args...); } private: ClassT &Instance; MethodT Method; }; + // Helper that provides a Functor for deserializing arguments. template <typename... ArgTs> class ReadArgs { public: - std::error_code operator()() { return std::error_code(); } + Error operator()() { return Error::success(); } }; template <typename ArgT, typename... ArgTs> @@ -94,7 +272,7 @@ protected: ReadArgs(ArgT &Arg, ArgTs &... Args) : ReadArgs<ArgTs...>(Args...), Arg(Arg) {} - std::error_code operator()(ArgT &ArgVal, ArgTs &... ArgVals) { + Error operator()(ArgT &ArgVal, ArgTs &... ArgVals) { this->Arg = std::move(ArgVal); return ReadArgs<ArgTs...>::operator()(ArgVals...); } @@ -106,7 +284,7 @@ protected: /// Contains primitive utilities for defining, calling and handling calls to /// remote procedures. ChannelT is a bidirectional stream conforming to the -/// RPCChannel interface (see RPCChannel.h), and ProcedureIdT is a procedure +/// RPCChannel interface (see RPCChannel.h), and FunctionIdT is a procedure /// identifier type that must be serializable on ChannelT. /// /// These utilities support the construction of very primitive RPC utilities. 
@@ -123,120 +301,223 @@ protected: /// /// Overview (see comments individual types/methods for details): /// -/// Procedure<Id, Args...> : +/// Function<Id, RetT(Args...)> : /// /// associates a unique serializable id with an argument list. /// /// -/// call<Proc>(Channel, Args...) : +/// call<Func>(Channel, Args...) : /// -/// Calls the remote procedure 'Proc' by serializing Proc's id followed by its +/// Calls the remote procedure 'Func' by serializing Func's id followed by its /// arguments and sending the resulting bytes to 'Channel'. /// /// -/// handle<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// handle<Func>(Channel, <functor matching Error(Args...)> : /// -/// Handles a call to 'Proc' by deserializing its arguments and calling the -/// given functor. This assumes that the id for 'Proc' has already been +/// Handles a call to 'Func' by deserializing its arguments and calling the +/// given functor. This assumes that the id for 'Func' has already been /// deserialized. /// -/// expect<Proc>(Channel, <functor matching std::error_code(Args...)> : +/// expect<Func>(Channel, <functor matching Error(Args...)> : /// /// The same as 'handle', except that the procedure id should not have been -/// read yet. Expect will deserialize the id and assert that it matches Proc's +/// read yet. Expect will deserialize the id and assert that it matches Func's -/// id. If it does not, and unexpected RPC call error is returned. +/// id. If it does not, an unexpected RPC call error is returned. - -template <typename ChannelT, typename ProcedureIdT = uint32_t> +template <typename ChannelT, typename FunctionIdT = uint32_t, + typename SequenceNumberT = uint16_t> class RPC : public RPCBase { public: + /// RPC default constructor. + RPC() = default; + + /// RPC instances cannot be copied. + RPC(const RPC &) = delete; + + /// RPC instances cannot be copied. + RPC &operator=(const RPC &) = delete; + + /// RPC move constructor. + // FIXME: Remove once MSVC can synthesize move ops.
+ RPC(RPC &&Other) + : SequenceNumberMgr(std::move(Other.SequenceNumberMgr)), + OutstandingResults(std::move(Other.OutstandingResults)) {} + + /// RPC move assignment. + // FIXME: Remove once MSVC can synthesize move ops. + RPC &operator=(RPC &&Other) { + SequenceNumberMgr = std::move(Other.SequenceNumberMgr); + OutstandingResults = std::move(Other.OutstandingResults); + return *this; + } + /// Utility class for defining/referring to RPC procedures. /// /// Typedefs of this utility are used when calling/handling remote procedures. /// - /// ProcId should be a unique value of ProcedureIdT (i.e. not used with any - /// other Procedure typedef in the RPC API being defined. + /// FuncId should be a unique value of FunctionIdT (i.e. not used with any + /// other Function typedef in the RPC API being defined). /// - /// the template argument Ts... gives the argument list for the remote - /// procedure. + /// The template argument FnT gives the signature (return type and argument + /// list) of the remote procedure. /// /// E.g. /// - /// typedef Procedure<0, bool> Proc1; - /// typedef Procedure<1, std::string, std::vector<int>> Proc2; + /// typedef Function<0, void(bool)> Func1; + /// typedef Function<1, void(std::string, std::vector<int>)> Func2; /// - /// if (auto EC = call<Proc1>(Channel, true)) - /// /* handle EC */; + /// if (auto Err = call<Func1>(Channel, true)) + /// /* handle Err */; /// - /// if (auto EC = expect<Proc2>(Channel, + /// if (auto Err = expect<Func2>(Channel, /// [](std::string &S, std::vector<int> &V) { /// // Stuff. - /// return std::error_code(); + /// return Error::success(); /// }) - /// /* handle EC */; + /// /* handle Err */; /// - template <ProcedureIdT ProcId, typename... Ts> - using Procedure = ProcedureHelper<ProcedureIdT, ProcId, Ts...>; + template <FunctionIdT FuncId, typename FnT> + using Function = FunctionHelper<FunctionIdT, FuncId, FnT>; + + /// Return type for asynchronous call primitives. + template <typename Func> + using AsyncCallResult = std::future<typename Func::OptionalReturn>; + + /// Return type for asynchronous call-with-seq primitives.
+ template <typename Func> + using AsyncCallWithSeqResult = + std::pair<std::future<typename Func::OptionalReturn>, SequenceNumberT>; /// Serialize Args... to channel C, but do not call C.send(). /// - /// For buffered channels, this can be used to queue up several calls before - /// flushing the channel. - template <typename Proc, typename... ArgTs> - static std::error_code appendCall(ChannelT &C, const ArgTs &... Args) { - return CallHelper<ChannelT, Proc>::call(C, Args...); + /// Returns an error (on serialization failure) or a pair of: + /// (1) A future Optional<T> (or future<bool> for void functions), and + /// (2) A sequence number. + /// + /// This utility function is primarily used for single-threaded mode support, + /// where the sequence number can be used to wait for the corresponding + /// result. In multi-threaded mode the appendCallAsync method, which does not + /// return the sequence number, should be preferred. + template <typename Func, typename... ArgTs> + Expected<AsyncCallWithSeqResult<Func>> + appendCallAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { + auto SeqNo = SequenceNumberMgr.getSequenceNumber(); + std::promise<typename Func::OptionalReturn> Promise; + auto Result = Promise.get_future(); + OutstandingResults[SeqNo] = + createOutstandingResult<Func>(std::move(Promise)); + + if (auto Err = CallHelper<ChannelT, SequenceNumberT, Func>::call(C, SeqNo, + Args...)) { + // Serialization failed: abandon every outstanding result, not just this one. + abandonOutstandingResults(); + return std::move(Err); + } else + return AsyncCallWithSeqResult<Func>(std::move(Result), SeqNo); } - /// Serialize Args... to channel C and call C.send(). - template <typename Proc, typename... ArgTs> - static std::error_code call(ChannelT &C, const ArgTs &... Args) { - if (auto EC = appendCall<Proc>(C, Args...)) - return EC; - return C.send(); + /// The same as appendCallAsyncWithSeq, except that it calls C.send() to + /// flush the channel after serializing the call. + template <typename Func, typename...
ArgTs> + Expected<AsyncCallWithSeqResult<Func>> + callAsyncWithSeq(ChannelT &C, const ArgTs &... Args) { + auto Result = appendCallAsyncWithSeq<Func>(C, Args...); + if (!Result) + return Result; + if (auto Err = C.send()) { + abandonOutstandingResults(); + return std::move(Err); + } + return Result; + } + + /// Serialize Args... to channel C, but do not call send. + /// Returns an error if serialization fails, otherwise returns a + /// std::future<Optional<T>> (or a future<bool> for void functions). + template <typename Func, typename... ArgTs> + Expected<AsyncCallResult<Func>> appendCallAsync(ChannelT &C, + const ArgTs &... Args) { + auto ResAndSeqOrErr = appendCallAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.takeError(); + } + + /// The same as appendCallAsync, except that it calls C.send to flush the + /// channel after serializing the call. + template <typename Func, typename... ArgTs> + Expected<AsyncCallResult<Func>> callAsync(ChannelT &C, + const ArgTs &... Args) { + auto ResAndSeqOrErr = callAsyncWithSeq<Func>(C, Args...); + if (ResAndSeqOrErr) + return std::move(ResAndSeqOrErr->first); + return ResAndSeqOrErr.takeError(); + } + + /// This can be used in single-threaded mode. + template <typename Func, typename HandleFtor, typename... ArgTs> + typename Func::ErrorReturn + callSTHandling(ChannelT &C, HandleFtor &HandleOther, const ArgTs &... Args) { + if (auto ResultAndSeqNoOrErr = callAsyncWithSeq<Func>(C, Args...)) { + auto &ResultAndSeqNo = *ResultAndSeqNoOrErr; + if (auto Err = waitForResult(C, ResultAndSeqNo.second, HandleOther)) + return std::move(Err); + return Func::optionalToErrorReturn(ResultAndSeqNo.first.get()); + } else + return ResultAndSeqNoOrErr.takeError(); } - /// Deserialize and return an enum whose underlying type is ProcedureIdT. - static std::error_code getNextProcId(ChannelT &C, ProcedureIdT &Id) { + // This can be used in single-threaded mode.
+ template <typename Func, typename... ArgTs> + typename Func::ErrorReturn callST(ChannelT &C, const ArgTs &... Args) { + return callSTHandling<Func>(C, handleNone, Args...); + } + + /// Start receiving a new function call. + /// + /// Calls startReceiveMessage on the channel, then deserializes a FunctionId + /// into Id. + Error startReceivingFunction(ChannelT &C, FunctionIdT &Id) { + if (auto Err = startReceiveMessage(C)) + return Err; + return deserialize(C, Id); } - /// Deserialize args for Proc from C and call Handler. The signature of - /// handler must conform to 'std::error_code(Args...)' where Args... matches - /// the arguments used in the Proc typedef. - template <typename Proc, typename HandlerT> - static std::error_code handle(ChannelT &C, HandlerT Handler) { - return HandlerHelper<ChannelT, Proc>::handle(C, Handler); + /// Deserialize args for Func from C and call Handler. The signature of + /// handler must conform to 'Error(Args...)' where Args... matches + /// the arguments used in the Func typedef. + template <typename Func, typename HandlerT> + static Error handle(ChannelT &C, HandlerT Handler) { + return HandlerHelper<ChannelT, SequenceNumberT, Func>::handle(C, Handler); } /// Helper version of 'handle' for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> - static std::error_code - handle(ChannelT &C, ClassT &Instance, - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return handle<Proc>( - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + static Error handle(ChannelT &C, ClassT &Instance, + RetT (ClassT::*HandlerMethod)(ArgTs...)) { + return handle<Func>( + C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); } - /// Deserialize a ProcedureIdT from C and verify it matches the id for Proc. + /// Deserialize a FunctionIdT from C and verify it matches the id for Func.
/// If the id does match, deserialize the arguments and call the handler /// (similarly to handle). - /// If the id does not match, return an unexpect RPC call error and do not + /// If the id does not match, return an unexpected RPC call error and do not /// deserialize any further bytes. - template <typename Proc, typename HandlerT> - static std::error_code expect(ChannelT &C, HandlerT Handler) { - ProcedureIdT ProcId; - if (auto EC = getNextProcId(C, ProcId)) - return EC; - if (ProcId != Proc::Id) + template <typename Func, typename HandlerT> + Error expect(ChannelT &C, HandlerT Handler) { + FunctionIdT FuncId; + if (auto Err = startReceivingFunction(C, FuncId)) + return std::move(Err); + if (FuncId != Func::Id) return orcError(OrcErrorCode::UnexpectedRPCCall); - return handle<Proc>(C, Handler); + return handle<Func>(C, Handler); } /// Helper version of expect for calling member functions. - template <typename Proc, typename ClassT, typename... ArgTs> - static std::error_code - expect(ChannelT &C, ClassT &Instance, - std::error_code (ClassT::*HandlerMethod)(ArgTs...)) { - return expect<Proc>( - C, MemberFnWrapper<ClassT, ArgTs...>(Instance, HandlerMethod)); + // Not static: it forwards to the two-argument expect(), which reads the + // function id through the non-static startReceivingFunction(). + template <typename Func, typename ClassT, typename RetT, typename... ArgTs> + Error expect(ChannelT &C, ClassT &Instance, + RetT (ClassT::*HandlerMethod)(ArgTs...)) { + return expect<Func>( + C, MemberFnWrapper<ClassT, RetT, ArgTs...>(Instance, HandlerMethod)); } @@ -245,18 +526,165 @@ public: /// channel. /// E.g. /// - /// typedef Procedure<0, bool, int> Proc1; + /// typedef Function<0, void(bool, int)> Func1; /// /// ... /// bool B; /// int I; - /// if (auto EC = expect<Proc1>(Channel, readArgs(B, I))) + /// if (auto Err = expect<Func1>(Channel, readArgs(B, I))) /// /* Handle Args */ ; /// template <typename... ArgTs> static ReadArgs<ArgTs...> readArgs(ArgTs &... Args) { return ReadArgs<ArgTs...>(Args...); } + + /// Read a response from Channel. + /// This should be called from the receive loop to retrieve results.
+ Error handleResponse(ChannelT &C, SequenceNumberT *SeqNoRet = nullptr) { + SequenceNumberT SeqNo; + if (auto Err = deserialize(C, SeqNo)) { + abandonOutstandingResults(); + return Err; + } + + if (SeqNoRet) + *SeqNoRet = SeqNo; + + auto I = OutstandingResults.find(SeqNo); + if (I == OutstandingResults.end()) { + abandonOutstandingResults(); + return orcError(OrcErrorCode::UnexpectedRPCResponse); + } + + if (auto Err = I->second->readResult(C)) { + abandonOutstandingResults(); + // No explicit release needed: abandonOutstandingResults() resets the + // sequence number manager. + return Err; + } + + OutstandingResults.erase(I); + SequenceNumberMgr.releaseSequenceNumber(SeqNo); + + return Error::success(); + } + + // Loop waiting for a result with the given sequence number. + // This can be used as a receive loop if the user doesn't have a default. + template <typename HandleOtherFtor> + Error waitForResult(ChannelT &C, SequenceNumberT TgtSeqNo, + HandleOtherFtor &HandleOther = handleNone) { + bool GotTgtResult = false; + + while (!GotTgtResult) { + FunctionIdT Id = RPCFunctionIdTraits<FunctionIdT>::InvalidId; + if (auto Err = startReceivingFunction(C, Id)) + return Err; + if (Id == RPCFunctionIdTraits<FunctionIdT>::ResponseId) { + SequenceNumberT SeqNo; + if (auto Err = handleResponse(C, &SeqNo)) + return Err; + // Responses for other sequence numbers are processed too; only the + // response for TgtSeqNo ends the loop. + GotTgtResult = (SeqNo == TgtSeqNo); + } else if (auto Err = HandleOther(C, Id)) + return Err; + } + + return Error::success(); + } + + // Default handler for 'other' (non-response) functions when waiting for a + // result from the channel. + static Error handleNone(ChannelT &, FunctionIdT) { + return orcError(OrcErrorCode::UnexpectedRPCCall); + } + +private: + // Manage sequence numbers.
+ class SequenceNumberManager { + public: + SequenceNumberManager() = default; + + SequenceNumberManager(const SequenceNumberManager &) = delete; + SequenceNumberManager &operator=(const SequenceNumberManager &) = delete; + + // std::mutex is not movable, so the move operations transfer only the + // counter state; each object keeps its own lock. + SequenceNumberManager(SequenceNumberManager &&Other) + : NextSequenceNumber(std::move(Other.NextSequenceNumber)), + FreeSequenceNumbers(std::move(Other.FreeSequenceNumbers)) {} + + SequenceNumberManager &operator=(SequenceNumberManager &&Other) { + NextSequenceNumber = std::move(Other.NextSequenceNumber); + FreeSequenceNumbers = std::move(Other.FreeSequenceNumbers); + return *this; + } + + void reset() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + NextSequenceNumber = 0; + FreeSequenceNumbers.clear(); + } + + SequenceNumberT getSequenceNumber() { + std::lock_guard<std::mutex> Lock(SeqNoLock); + if (FreeSequenceNumbers.empty()) + return NextSequenceNumber++; + auto SequenceNumber = FreeSequenceNumbers.back(); + FreeSequenceNumbers.pop_back(); + return SequenceNumber; + } + + void releaseSequenceNumber(SequenceNumberT SequenceNumber) { + std::lock_guard<std::mutex> Lock(SeqNoLock); + FreeSequenceNumbers.push_back(SequenceNumber); + } + + private: + std::mutex SeqNoLock; + SequenceNumberT NextSequenceNumber = 0; + std::vector<SequenceNumberT> FreeSequenceNumbers; + }; + + // Base class for results that haven't been returned from the other end of the + // RPC connection yet. + class OutstandingResult { + public: + virtual ~OutstandingResult() {} + virtual Error readResult(ChannelT &C) = 0; + virtual void abandon() = 0; + }; + + // Outstanding results for a specific function.
+ template <typename Func> + class OutstandingResultImpl : public OutstandingResult { + public: + explicit OutstandingResultImpl(std::promise<typename Func::OptionalReturn> &&P) + : P(std::move(P)) {} + + Error readResult(ChannelT &C) override { return Func::readResult(C, P); } + + void abandon() override { Func::abandon(P); } + + private: + std::promise<typename Func::OptionalReturn> P; + }; + + // Create an outstanding result for the given function. + template <typename Func> + std::unique_ptr<OutstandingResult> + createOutstandingResult(std::promise<typename Func::OptionalReturn> &&P) { + return llvm::make_unique<OutstandingResultImpl<Func>>(std::move(P)); + } + + // Abandon all outstanding results and reset sequence number allocation. + void abandonOutstandingResults() { + for (auto &KV : OutstandingResults) + KV.second->abandon(); + OutstandingResults.clear(); + SequenceNumberMgr.reset(); + } + + SequenceNumberManager SequenceNumberMgr; + std::map<SequenceNumberT, std::unique_ptr<OutstandingResult>> + OutstandingResults; }; } // end namespace remote |