Skip to content

Commit

Permalink
CursorSerializer 12/n: Support writing with cursor inside a container
Browse files Browse the repository at this point in the history
Summary:
Adds beginWrite/endWrite methods for both structured types and containers inside lists and sets.
This API would be fairly clunky for maps as the caller would have to alternate between writing keys and values, so that is not implemented in this diff.

Reviewed By: thedavekwon

Differential Revision: D57298656

fbshipit-source-id: 5aefed0caa3bbe4e4926f855c27cdb6ee7b08c07
  • Loading branch information
iahs authored and facebook-github-bot committed May 17, 2024
1 parent 7276efd commit 3843f1c
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -761,8 +761,10 @@ class StructuredCursorWriter : detail::BaseCursorWriter {

FieldId fieldId_{0};

template <typename U>
template <typename>
friend class StructuredCursorWriter;
template <typename>
friend class ContainerCursorWriter;
friend class CursorSerializationWrapper<T>;
friend struct detail::DefaultValueWriter<Tag>;
};
Expand All @@ -780,17 +782,81 @@ class StructuredCursorWriter : detail::BaseCursorWriter {
*/
template <typename Tag>
class ContainerCursorWriter : detail::DelayedSizeCursorWriter {
using ElementType = typename detail::ContainerTraits<Tag>::ElementType;
using ElementTag = typename detail::ContainerTraits<Tag>::ElementTag;
template <typename CTag, typename OwnTag>
using enable_cursor_for = std::enable_if_t<
(type::is_a_v<OwnTag, type::list_c> ||
type::is_a_v<OwnTag, type::set_c>) &&
type::is_a_v<ElementTag, CTag>,
int>;

public:
void write(const typename detail::ContainerTraits<Tag>::ElementType& val) {
void write(const ElementType& val) {
checkState(State::Active);
++n;
detail::ContainerTraits<Tag>::write(protocol_, val);
}

/**
* Allows writing containers whose size isn't known until afterwards.
* Less efficient than using write().
* See the ContainerCursorWriter docblock for example usage.
*
* Note: none of this writer's other methods may be called between
* beginWrite() and the corresponding endWrite().
*/
template <
typename...,
typename U = Tag,
enable_cursor_for<type::container_c, U> = 0>
ContainerCursorWriter<ElementTag> beginWrite() {
checkState(State::Active);
state_ = State::Child;
return ContainerCursorWriter<ElementTag>{std::move(protocol_)};
}

template <typename CTag>
void endWrite(ContainerCursorWriter<CTag>&& child) {
checkState(State::Child);
child.finalize();
protocol_ = std::move(child.protocol_);
++n;
state_ = State::Active;
}

/**
* structured types
*
* Note: none of this writer's other methods may be called between
* beginWrite() and the corresponding endWrite().
*/
template <
typename...,
typename U = Tag,
enable_cursor_for<type::structured_c, U> = 0>
StructuredCursorWriter<ElementTag> beginWrite() {
checkState(State::Active);
state_ = State::Child;
return StructuredCursorWriter<ElementTag>{std::move(protocol_)};
}

template <typename CTag>
void endWrite(StructuredCursorWriter<CTag>&& child) {
checkState(State::Child);
child.finalize();
protocol_ = std::move(child.protocol_);
++n;
state_ = State::Active;
}

private:
explicit ContainerCursorWriter(BinaryProtocolWriter&& p);

template <typename T>
template <typename>
friend class StructuredCursorWriter;
template <typename>
friend class ContainerCursorWriter;

void finalize() { DelayedSizeCursorWriter::finalize(n); }

Expand Down Expand Up @@ -860,10 +926,10 @@ ContainerCursorWriter<Tag>::ContainerCursorWriter(BinaryProtocolWriter&& p)
: DelayedSizeCursorWriter(std::move(p)) {
if constexpr (type::is_a_v<Tag, type::list_c>) {
protocol_.writeByte(
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ValueTag>);
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ElementTag>);
} else if constexpr (type::is_a_v<Tag, type::set_c>) {
protocol_.writeByte(
op::typeTagToTType<typename detail::ContainerTraits<Tag>::KeyTag>);
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ElementTag>);
} else if constexpr (type::is_a_v<Tag, type::map_c>) {
protocol_.writeByte(
op::typeTagToTType<typename detail::ContainerTraits<Tag>::KeyTag>);
Expand Down Expand Up @@ -908,15 +974,15 @@ ContainerCursorReader<Tag>::ContainerCursorReader(BinaryProtocolReader&& p)
TType type;
protocol_.readListBegin(type, remaining_);
if (type !=
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ValueTag>) {
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ElementTag>) {
folly::throw_exception<std::runtime_error>(
"Unexpected element type in list");
}
} else if constexpr (type::is_a_v<Tag, type::set_c>) {
TType type;
protocol_.readSetBegin(type, remaining_);
if (type !=
op::typeTagToTType<typename detail::ContainerTraits<Tag>::KeyTag>) {
op::typeTagToTType<typename detail::ContainerTraits<Tag>::ElementTag>) {
folly::throw_exception<std::runtime_error>(
"Unexpected element type in set");
}
Expand Down Expand Up @@ -965,10 +1031,10 @@ void ContainerCursorReader<Tag>::read() {
DCHECK_GT(remaining_, 0);

if constexpr (type::is_a_v<Tag, type::list_c>) {
op::decode<typename detail::ContainerTraits<Tag>::ValueTag>(
op::decode<typename detail::ContainerTraits<Tag>::ElementTag>(
protocol_, lastRead_);
} else if constexpr (type::is_a_v<Tag, type::set_c>) {
op::decode<typename detail::ContainerTraits<Tag>::KeyTag>(
op::decode<typename detail::ContainerTraits<Tag>::ElementTag>(
protocol_, lastRead_);
} else if constexpr (type::is_a_v<Tag, type::map_c>) {
op::decode<typename detail::ContainerTraits<Tag>::KeyTag>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ struct ContainerTraits;
template <typename VTag>
struct ContainerTraits<type::list<VTag>> {
using ElementType = type::native_type<VTag>;
using ValueTag = VTag;
using ElementTag = VTag;
// This is initializer_list becuase that's what skip_n accepts.
static constexpr std::initializer_list<protocol::TType> wireTypes = {
op::typeTagToTType<ValueTag>};
op::typeTagToTType<ElementTag>};

static void write(BinaryProtocolWriter& protocol, const ElementType& value) {
op::encode<VTag>(protocol, value);
Expand All @@ -68,9 +68,9 @@ struct ContainerTraits<type::list<VTag>> {
template <typename KTag>
struct ContainerTraits<type::set<KTag>> {
using ElementType = type::native_type<KTag>;
using KeyTag = KTag;
using ElementTag = KTag;
static constexpr std::initializer_list<protocol::TType> wireTypes = {
op::typeTagToTType<KeyTag>};
op::typeTagToTType<ElementTag>};

static void write(BinaryProtocolWriter& protocol, const ElementType& value) {
op::encode<KTag>(protocol, value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <thrift/lib/cpp2/protocol/test/gen-cpp2/cursor_handlers.h>
#include <thrift/lib/cpp2/protocol/test/gen-cpp2/cursor_types.h>
#include <thrift/lib/cpp2/util/ScopedServerInterfaceThread.h>
#include <thrift/lib/cpp2/util/gtest/Matcher.h>

using namespace apache::thrift;
using namespace apache::thrift::test;
Expand Down Expand Up @@ -427,3 +428,22 @@ TEST(CursorSerializer, NestedStructWrite) {
EXPECT_EQ(obj.union_field()->getType(), Inner::Type::__EMPTY__);
EXPECT_THAT(*obj.list_field(), ElementsAre(42));
}

TEST(CursorSerializer, CursorWriteInContainer) {
StructCursor wrapper;
auto writer = wrapper.beginWrite();
auto listWriter = writer.beginWrite<ident::set_nested_field>();
auto setWriter = listWriter.beginWrite();
auto innerWriter = setWriter.beginWrite();
innerWriter.write<ident::string_field>("foo");
setWriter.endWrite(std::move(innerWriter));
listWriter.endWrite(std::move(setWriter));
writer.endWrite(std::move(listWriter));
wrapper.endWrite(std::move(writer));

auto obj = wrapper.deserialize();
LOG(INFO) << debugStringViaEncode(obj);
EXPECT_THAT(
*obj.set_nested_field(),
Contains(Contains(IsThriftUnionWith<ident::string_field>(Eq("foo")))));
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct Struct {
2: i32 i32_field;
3: Inner union_field;
4: list<byte> list_field;
5: list<set<string>> set_nested_field;
5: list<set<Stringish>> set_nested_field;
@cpp.Type{template = "std::unordered_map"}
6: map<byte, byte> map_field;
}
Expand Down

0 comments on commit 3843f1c

Please sign in to comment.