diff options
Diffstat (limited to 'llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h')
-rw-r--r-- | llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h | 703 |
1 files changed, 703 insertions, 0 deletions
diff --git a/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h new file mode 100644 index 000000000000..9c69a84f4c67 --- /dev/null +++ b/llvm/include/llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h @@ -0,0 +1,703 @@ +//===- llvm/ExecutionEngine/Orc/RPC/RPCSerialization.h --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H +#define LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H + +#include "llvm/ExecutionEngine/Orc/OrcError.h" +#include "llvm/Support/thread.h" +#include <map> +#include <mutex> +#include <set> +#include <sstream> +#include <string> +#include <vector> + +namespace llvm { +namespace orc { +namespace rpc { + +template <typename T> +class RPCTypeName; + +/// TypeNameSequence is a utility for rendering sequences of types to a string +/// by rendering each type, separated by ", ". +template <typename... ArgTs> class RPCTypeNameSequence {}; + +/// Render an empty TypeNameSequence to an ostream. +template <typename OStream> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<> &V) { + return OS; +} + +/// Render a TypeNameSequence of a single type to an ostream. +template <typename OStream, typename ArgT> +OStream &operator<<(OStream &OS, const RPCTypeNameSequence<ArgT> &V) { + OS << RPCTypeName<ArgT>::getName(); + return OS; +} + +/// Render a TypeNameSequence of more than one type to an ostream. +template <typename OStream, typename ArgT1, typename ArgT2, typename... ArgTs> +OStream& +operator<<(OStream &OS, const RPCTypeNameSequence<ArgT1, ArgT2, ArgTs...> &V) { + OS << RPCTypeName<ArgT1>::getName() << ", " + << RPCTypeNameSequence<ArgT2, ArgTs...>(); + return OS; +} + +template <> +class RPCTypeName<void> { +public: + static const char* getName() { return "void"; } +}; + +template <> +class RPCTypeName<int8_t> { +public: + static const char* getName() { return "int8_t"; } +}; + +template <> +class RPCTypeName<uint8_t> { +public: + static const char* getName() { return "uint8_t"; } +}; + +template <> +class RPCTypeName<int16_t> { +public: + static const char* getName() { return "int16_t"; } +}; + +template <> +class RPCTypeName<uint16_t> { +public: + static const char* getName() { return "uint16_t"; } +}; + +template <> +class RPCTypeName<int32_t> { +public: + static const char* getName() { return "int32_t"; } +}; + +template <> +class RPCTypeName<uint32_t> { +public: + static const char* getName() { return "uint32_t"; } +}; + +template <> +class RPCTypeName<int64_t> { +public: + static const char* getName() { return "int64_t"; } +}; + +template <> +class RPCTypeName<uint64_t> { +public: + static const char* getName() { return "uint64_t"; } +}; + +template <> +class RPCTypeName<bool> { +public: + static const char* getName() { return "bool"; } +}; + +template <> +class RPCTypeName<std::string> { +public: + static const char* getName() { return "std::string"; } +}; + +template <> +class RPCTypeName<Error> { +public: + static const char* getName() { return "Error"; } +}; + +template <typename T> +class RPCTypeName<Expected<T>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "Expected<" + << RPCTypeNameSequence<T>() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T1, typename T2> +class RPCTypeName<std::pair<T1, T2>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::pair<" << RPCTypeNameSequence<T1, T2>() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename... ArgTs> +class RPCTypeName<std::tuple<ArgTs...>> { +public: + static const char* getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::tuple<" + << RPCTypeNameSequence<ArgTs...>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> +class RPCTypeName<std::vector<T>> { +public: + static const char*getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) << "std::vector<" << RPCTypeName<T>::getName() + << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename T> class RPCTypeName<std::set<T>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::set<" << RPCTypeName<T>::getName() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +template <typename K, typename V> class RPCTypeName<std::map<K, V>> { +public: + static const char *getName() { + static std::string Name = [] { + std::string Name; + raw_string_ostream(Name) + << "std::map<" << RPCTypeNameSequence<K, V>() << ">"; + return Name; + }(); + return Name.data(); + } +}; + +/// The SerializationTraits<ChannelT, T> class describes how to serialize and +/// deserialize an instance of type T to/from an abstract channel of type +/// ChannelT. It also provides a representation of the type's name via the +/// getName method. +/// +/// Specializations of this class should provide the following functions: +/// +/// @code{.cpp} +/// +/// static const char* getName(); +/// static Error serialize(ChannelT&, const T&); +/// static Error deserialize(ChannelT&, T&); +/// +/// @endcode +/// +/// The third argument of SerializationTraits is intended to support SFINAE. +/// E.g.: +/// +/// @code{.cpp} +/// +/// class MyVirtualChannel { ... }; +/// +/// template <DerivedChannelT> +/// class SerializationTraits<DerivedChannelT, bool, +/// typename std::enable_if< +/// std::is_base_of<VirtChannel, DerivedChannel>::value +/// >::type> { +/// public: +/// static const char* getName() { ... }; +/// } +/// +/// @endcode +template <typename ChannelT, typename WireType, + typename ConcreteType = WireType, typename = void> +class SerializationTraits; + +template <typename ChannelT> +class SequenceTraits { +public: + static Error emitSeparator(ChannelT &C) { return Error::success(); } + static Error consumeSeparator(ChannelT &C) { return Error::success(); } +}; + +/// Utility class for serializing sequences of values of varying types. +/// Specializations of this class contain 'serialize' and 'deserialize' methods +/// for the given channel. The ArgTs... list will determine the "over-the-wire" +/// types to be serialized. The serialize and deserialize methods take a list +/// CArgTs... ("caller arg types") which must be the same length as ArgTs..., +/// but may be different types from ArgTs, provided that for each CArgT there +/// is a SerializationTraits specialization +/// SerializeTraits<ChannelT, ArgT, CArgT> with methods that can serialize the +/// caller argument to over-the-wire value. +template <typename ChannelT, typename... ArgTs> +class SequenceSerialization; + +template <typename ChannelT> +class SequenceSerialization<ChannelT> { +public: + static Error serialize(ChannelT &C) { return Error::success(); } + static Error deserialize(ChannelT &C) { return Error::success(); } +}; + +template <typename ChannelT, typename ArgT> +class SequenceSerialization<ChannelT, ArgT> { +public: + + template <typename CArgT> + static Error serialize(ChannelT &C, CArgT &&CArg) { + return SerializationTraits<ChannelT, ArgT, + typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg)); + } + + template <typename CArgT> + static Error deserialize(ChannelT &C, CArgT &CArg) { + return SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg); + } +}; + +template <typename ChannelT, typename ArgT, typename... ArgTs> +class SequenceSerialization<ChannelT, ArgT, ArgTs...> { +public: + + template <typename CArgT, typename... CArgTs> + static Error serialize(ChannelT &C, CArgT &&CArg, + CArgTs &&... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, typename std::decay<CArgT>::type>:: + serialize(C, std::forward<CArgT>(CArg))) + return Err; + if (auto Err = SequenceTraits<ChannelT>::emitSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>:: + serialize(C, std::forward<CArgTs>(CArgs)...); + } + + template <typename CArgT, typename... CArgTs> + static Error deserialize(ChannelT &C, CArgT &CArg, + CArgTs &... CArgs) { + if (auto Err = + SerializationTraits<ChannelT, ArgT, CArgT>::deserialize(C, CArg)) + return Err; + if (auto Err = SequenceTraits<ChannelT>::consumeSeparator(C)) + return Err; + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, CArgs...); + } +}; + +template <typename ChannelT, typename... ArgTs> +Error serializeSeq(ChannelT &C, ArgTs &&... Args) { + return SequenceSerialization<ChannelT, typename std::decay<ArgTs>::type...>:: + serialize(C, std::forward<ArgTs>(Args)...); +} + +template <typename ChannelT, typename... ArgTs> +Error deserializeSeq(ChannelT &C, ArgTs &... Args) { + return SequenceSerialization<ChannelT, ArgTs...>::deserialize(C, Args...); +} + +template <typename ChannelT> +class SerializationTraits<ChannelT, Error> { +public: + + using WrappedErrorSerializer = + std::function<Error(ChannelT &C, const ErrorInfoBase&)>; + + using WrappedErrorDeserializer = + std::function<Error(ChannelT &C, Error &Err)>; + + template <typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> + static void registerErrorType(std::string Name, SerializeFtor Serialize, + DeserializeFtor Deserialize) { + assert(!Name.empty() && + "The empty string is reserved for the Success value"); + + const std::string *KeyName = nullptr; + { + // We're abusing the stability of std::map here: We take a reference to the + // key of the deserializers map to save us from duplicating the string in + // the serializer. This should be changed to use a stringpool if we switch + // to a map type that may move keys in memory. + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + auto I = + Deserializers.insert(Deserializers.begin(), + std::make_pair(std::move(Name), + std::move(Deserialize))); + KeyName = &I->first; + } + + { + assert(KeyName != nullptr && "No keyname pointer"); + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + Serializers[ErrorInfoT::classID()] = + [KeyName, Serialize = std::move(Serialize)]( + ChannelT &C, const ErrorInfoBase &EIB) -> Error { + assert(EIB.dynamicClassID() == ErrorInfoT::classID() && + "Serializer called for wrong error type"); + if (auto Err = serializeSeq(C, *KeyName)) + return Err; + return Serialize(C, static_cast<const ErrorInfoT &>(EIB)); + }; + } + } + + static Error serialize(ChannelT &C, Error &&Err) { + std::lock_guard<std::recursive_mutex> Lock(SerializersMutex); + + if (!Err) + return serializeSeq(C, std::string()); + + return handleErrors(std::move(Err), + [&C](const ErrorInfoBase &EIB) { + auto SI = Serializers.find(EIB.dynamicClassID()); + if (SI == Serializers.end()) + return serializeAsStringError(C, EIB); + return (SI->second)(C, EIB); + }); + } + + static Error deserialize(ChannelT &C, Error &Err) { + std::lock_guard<std::recursive_mutex> Lock(DeserializersMutex); + + std::string Key; + if (auto Err = deserializeSeq(C, Key)) + return Err; + + if (Key.empty()) { + ErrorAsOutParameter EAO(&Err); + Err = Error::success(); + return Error::success(); + } + + auto DI = Deserializers.find(Key); + assert(DI != Deserializers.end() && "No deserializer for error type"); + return (DI->second)(C, Err); + } + +private: + + static Error serializeAsStringError(ChannelT &C, const ErrorInfoBase &EIB) { + std::string ErrMsg; + { + raw_string_ostream ErrMsgStream(ErrMsg); + EIB.log(ErrMsgStream); + } + return serialize(C, make_error<StringError>(std::move(ErrMsg), + inconvertibleErrorCode())); + } + + static std::recursive_mutex SerializersMutex; + static std::recursive_mutex DeserializersMutex; + static std::map<const void*, WrappedErrorSerializer> Serializers; + static std::map<std::string, WrappedErrorDeserializer> Deserializers; +}; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::SerializersMutex; + +template <typename ChannelT> +std::recursive_mutex SerializationTraits<ChannelT, Error>::DeserializersMutex; + +template <typename ChannelT> +std::map<const void*, + typename SerializationTraits<ChannelT, Error>::WrappedErrorSerializer> +SerializationTraits<ChannelT, Error>::Serializers; + +template <typename ChannelT> +std::map<std::string, + typename SerializationTraits<ChannelT, Error>::WrappedErrorDeserializer> +SerializationTraits<ChannelT, Error>::Deserializers; + +/// Registers a serializer and deserializer for the given error type on the +/// given channel type. +template <typename ChannelT, typename ErrorInfoT, typename SerializeFtor, + typename DeserializeFtor> +void registerErrorSerialization(std::string Name, SerializeFtor &&Serialize, + DeserializeFtor &&Deserialize) { + SerializationTraits<ChannelT, Error>::template registerErrorType<ErrorInfoT>( + std::move(Name), + std::forward<SerializeFtor>(Serialize), + std::forward<DeserializeFtor>(Deserialize)); +} + +/// Registers serialization/deserialization for StringError. +template <typename ChannelT> +void registerStringError() { + static bool AlreadyRegistered = false; + if (!AlreadyRegistered) { + registerErrorSerialization<ChannelT, StringError>( + "StringError", + [](ChannelT &C, const StringError &SE) { + return serializeSeq(C, SE.getMessage()); + }, + [](ChannelT &C, Error &Err) -> Error { + ErrorAsOutParameter EAO(&Err); + std::string Msg; + if (auto E2 = deserializeSeq(C, Msg)) + return E2; + Err = + make_error<StringError>(std::move(Msg), + orcError( + OrcErrorCode::UnknownErrorCodeFromRemote)); + return Error::success(); + }); + AlreadyRegistered = true; + } +} + +/// SerializationTraits for Expected<T1> from an Expected<T2>. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, Expected<T2>> { +public: + + static Error serialize(ChannelT &C, Expected<T2> &&ValOrErr) { + if (ValOrErr) { + if (auto Err = serializeSeq(C, true)) + return Err; + return SerializationTraits<ChannelT, T1, T2>::serialize(C, *ValOrErr); + } + if (auto Err = serializeSeq(C, false)) + return Err; + return serializeSeq(C, ValOrErr.takeError()); + } + + static Error deserialize(ChannelT &C, Expected<T2> &ValOrErr) { + ExpectedAsOutParameter<T2> EAO(&ValOrErr); + bool HasValue; + if (auto Err = deserializeSeq(C, HasValue)) + return Err; + if (HasValue) + return SerializationTraits<ChannelT, T1, T2>::deserialize(C, *ValOrErr); + Error Err = Error::success(); + if (auto E2 = deserializeSeq(C, Err)) + return E2; + ValOrErr = std::move(Err); + return Error::success(); + } +}; + +/// SerializationTraits for Expected<T1> from a T2. +template <typename ChannelT, typename T1, typename T2> +class SerializationTraits<ChannelT, Expected<T1>, T2> { +public: + + static Error serialize(ChannelT &C, T2 &&Val) { + return serializeSeq(C, Expected<T2>(std::forward<T2>(Val))); + } +}; + +/// SerializationTraits for Expected<T1> from an Error. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, Expected<T>, Error> { +public: + + static Error serialize(ChannelT &C, Error &&Err) { + return serializeSeq(C, Expected<T>(std::move(Err))); + } +}; + +/// SerializationTraits default specialization for std::pair. +template <typename ChannelT, typename T1, typename T2, typename T3, typename T4> +class SerializationTraits<ChannelT, std::pair<T1, T2>, std::pair<T3, T4>> { +public: + static Error serialize(ChannelT &C, const std::pair<T3, T4> &V) { + if (auto Err = SerializationTraits<ChannelT, T1, T3>::serialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::serialize(C, V.second); + } + + static Error deserialize(ChannelT &C, std::pair<T3, T4> &V) { + if (auto Err = + SerializationTraits<ChannelT, T1, T3>::deserialize(C, V.first)) + return Err; + return SerializationTraits<ChannelT, T2, T4>::deserialize(C, V.second); + } +}; + +/// SerializationTraits default specialization for std::tuple. +template <typename ChannelT, typename... ArgTs> +class SerializationTraits<ChannelT, std::tuple<ArgTs...>> { +public: + + /// RPC channel serialization for std::tuple. + static Error serialize(ChannelT &C, const std::tuple<ArgTs...> &V) { + return serializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + + /// RPC channel deserialization for std::tuple. + static Error deserialize(ChannelT &C, std::tuple<ArgTs...> &V) { + return deserializeTupleHelper(C, V, std::index_sequence_for<ArgTs...>()); + } + +private: + // Serialization helper for std::tuple. + template <size_t... Is> + static Error serializeTupleHelper(ChannelT &C, const std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return serializeSeq(C, std::get<Is>(V)...); + } + + // Serialization helper for std::tuple. + template <size_t... Is> + static Error deserializeTupleHelper(ChannelT &C, std::tuple<ArgTs...> &V, + std::index_sequence<Is...> _) { + return deserializeSeq(C, std::get<Is>(V)...); + } +}; + +/// SerializationTraits default specialization for std::vector. +template <typename ChannelT, typename T> +class SerializationTraits<ChannelT, std::vector<T>> { +public: + + /// Serialize a std::vector<T> from std::vector<T>. + static Error serialize(ChannelT &C, const std::vector<T> &V) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(V.size()))) + return Err; + + for (const auto &E : V) + if (auto Err = serializeSeq(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::vector<T> to a std::vector<T>. + static Error deserialize(ChannelT &C, std::vector<T> &V) { + assert(V.empty() && + "Expected default-constructed vector to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + V.resize(Count); + for (auto &E : V) + if (auto Err = deserializeSeq(C, E)) + return Err; + + return Error::success(); + } +}; + +template <typename ChannelT, typename T, typename T2> +class SerializationTraits<ChannelT, std::set<T>, std::set<T2>> { +public: + /// Serialize a std::set<T> from std::set<T2>. + static Error serialize(ChannelT &C, const std::set<T2> &S) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(S.size()))) + return Err; + + for (const auto &E : S) + if (auto Err = SerializationTraits<ChannelT, T, T2>::serialize(C, E)) + return Err; + + return Error::success(); + } + + /// Deserialize a std::set<T> to a std::set<T>. + static Error deserialize(ChannelT &C, std::set<T2> &S) { + assert(S.empty() && "Expected default-constructed set to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + T2 Val; + if (auto Err = SerializationTraits<ChannelT, T, T2>::deserialize(C, Val)) + return Err; + + auto Added = S.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized set", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +template <typename ChannelT, typename K, typename V, typename K2, typename V2> +class SerializationTraits<ChannelT, std::map<K, V>, std::map<K2, V2>> { +public: + /// Serialize a std::map<K, V> from std::map<K2, V2>. + static Error serialize(ChannelT &C, const std::map<K2, V2> &M) { + if (auto Err = serializeSeq(C, static_cast<uint64_t>(M.size()))) + return Err; + + for (const auto &E : M) { + if (auto Err = + SerializationTraits<ChannelT, K, K2>::serialize(C, E.first)) + return Err; + if (auto Err = + SerializationTraits<ChannelT, V, V2>::serialize(C, E.second)) + return Err; + } + + return Error::success(); + } + + /// Deserialize a std::map<K, V> to a std::map<K, V>. + static Error deserialize(ChannelT &C, std::map<K2, V2> &M) { + assert(M.empty() && "Expected default-constructed map to deserialize into"); + + uint64_t Count = 0; + if (auto Err = deserializeSeq(C, Count)) + return Err; + + while (Count-- != 0) { + std::pair<K2, V2> Val; + if (auto Err = + SerializationTraits<ChannelT, K, K2>::deserialize(C, Val.first)) + return Err; + + if (auto Err = + SerializationTraits<ChannelT, V, V2>::deserialize(C, Val.second)) + return Err; + + auto Added = M.insert(Val).second; + if (!Added) + return make_error<StringError>("Duplicate element in deserialized map", + orcError(OrcErrorCode::UnknownORCError)); + } + + return Error::success(); + } +}; + +} // end namespace rpc +} // end namespace orc +} // end namespace llvm + +#endif // LLVM_EXECUTIONENGINE_ORC_RPCSERIALIZATION_H |