Skip to content

Commit 396d661

Browse files
Add support for setting extensions with Ptr<Extension> in Upb C++ protos.
PiperOrigin-RevId: 633646216
1 parent 448e326 commit 396d661

File tree

3 files changed

+118
-23
lines changed

3 files changed

+118
-23
lines changed

protos/protos.h

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -340,29 +340,28 @@ ABSL_MUST_USE_RESULT bool HasExtension(
340340
return HasExtension(protos::Ptr(message), id);
341341
}
342342

343-
template <typename T, typename Extendee, typename Extension,
344-
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
343+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
344+
typename = EnableIfMutableProto<T>>
345345
void ClearExtension(
346346
Ptr<T> message,
347-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
347+
const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
348348
static_assert(!std::is_const_v<T>, "");
349349
upb_Message_ClearExtension(internal::GetInternalMsg(message),
350350
id.mini_table_ext());
351351
}
352352

353-
template <typename T, typename Extendee, typename Extension,
354-
typename = EnableIfProtosClass<T>>
353+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
355354
void ClearExtension(
356355
T* message,
357-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id) {
356+
const ::protos::internal::ExtensionIdentifier<T, Extension>& id) {
358357
ClearExtension(::protos::Ptr(message), id);
359358
}
360359

361-
template <typename T, typename Extendee, typename Extension,
362-
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
360+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
361+
typename = EnableIfMutableProto<T>>
363362
absl::Status SetExtension(
364363
Ptr<T> message,
365-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
364+
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
366365
const Extension& value) {
367366
static_assert(!std::is_const_v<T>);
368367
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
@@ -371,11 +370,24 @@ absl::Status SetExtension(
371370
internal::GetInternalMsg(&value));
372371
}
373372

374-
template <typename T, typename Extendee, typename Extension,
375-
typename = EnableIfProtosClass<T>, typename = EnableIfMutableProto<T>>
373+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
374+
typename = EnableIfMutableProto<T>>
376375
absl::Status SetExtension(
377376
Ptr<T> message,
378-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
377+
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
378+
Ptr<Extension> value) {
379+
static_assert(!std::is_const_v<T>);
380+
auto* message_arena = static_cast<upb_Arena*>(message->GetInternalArena());
381+
return ::protos::internal::SetExtension(internal::GetInternalMsg(message),
382+
message_arena, id.mini_table_ext(),
383+
internal::GetInternalMsg(value));
384+
}
385+
386+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>,
387+
typename = EnableIfMutableProto<T>>
388+
absl::Status SetExtension(
389+
Ptr<T> message,
390+
const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
379391
Extension&& value) {
380392
Extension ext = std::move(value);
381393
static_assert(!std::is_const_v<T>);
@@ -386,25 +398,28 @@ absl::Status SetExtension(
386398
internal::GetInternalMsg(&ext), extension_arena);
387399
}
388400

389-
template <typename T, typename Extendee, typename Extension,
390-
typename = EnableIfProtosClass<T>>
401+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
391402
absl::Status SetExtension(
392-
T* message,
393-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
403+
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
394404
const Extension& value) {
395405
return ::protos::SetExtension(::protos::Ptr(message), id, value);
396406
}
397407

398-
template <typename T, typename Extendee, typename Extension,
399-
typename = EnableIfProtosClass<T>>
408+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
400409
absl::Status SetExtension(
401-
T* message,
402-
const ::protos::internal::ExtensionIdentifier<Extendee, Extension>& id,
410+
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
403411
Extension&& value) {
404412
return ::protos::SetExtension(::protos::Ptr(message), id,
405413
std::forward<Extension>(value));
406414
}
407415

416+
template <typename T, typename Extension, typename = EnableIfProtosClass<T>>
417+
absl::Status SetExtension(
418+
T* message, const ::protos::internal::ExtensionIdentifier<T, Extension>& id,
419+
Ptr<Extension> value) {
420+
return ::protos::SetExtension(::protos::Ptr(message), id, value);
421+
}
422+
408423
template <typename T, typename Extendee, typename Extension,
409424
typename = EnableIfProtosClass<T>>
410425
absl::StatusOr<Ptr<const Extension>> GetExtension(

protos_generator/tests/test_generated.cc

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
// license that can be found in the LICENSE file or at
66
// https://developers.google.com/open-source/licenses/bsd
77

8+
#include <cstdint>
89
#include <iterator>
910
#include <limits>
1011
#include <memory>
@@ -24,10 +25,13 @@
2425
#include "protos_generator/tests/no_package.upb.proto.h"
2526
#include "protos_generator/tests/test_model.upb.proto.h"
2627
#include "upb/mem/arena.h"
28+
#include "upb/mem/arena.hpp"
2729

2830
namespace {
2931

3032
using ::protos_generator::test::protos::ChildModel1;
33+
using ::protos_generator::test::protos::container_ext;
34+
using ::protos_generator::test::protos::ContainerExtension;
3135
using ::protos_generator::test::protos::other_ext;
3236
using ::protos_generator::test::protos::RED;
3337
using ::protos_generator::test::protos::TestEnum;
@@ -39,7 +43,6 @@ using ::protos_generator::test::protos::TestModel_Category_VIDEO;
3943
using ::protos_generator::test::protos::theme;
4044
using ::protos_generator::test::protos::ThemeExtension;
4145
using ::testing::ElementsAre;
42-
using ::testing::HasSubstr;
4346

4447
// C++17 port of C++20 `requires`
4548
template <typename... T, typename F>
@@ -442,7 +445,7 @@ TEST(CppGeneratedCode, RepeatedScalarIterator) {
442445
EXPECT_EQ(sum, 5 + 16 + 27);
443446
// Access by const reference.
444447
sum = 0;
445-
for (const int& i : *test_model.mutable_value_array()) {
448+
for (const auto& i : *test_model.mutable_value_array()) {
446449
sum += i;
447450
}
448451
EXPECT_EQ(sum, 5 + 16 + 27);
@@ -551,7 +554,7 @@ TEST(CppGeneratedCode, RepeatedFieldProxyForMessages) {
551554
}
552555

553556
i = 0;
554-
for (auto child : *test_model.mutable_child_models()) {
557+
for (const auto& child : *test_model.mutable_child_models()) {
555558
if (i++ == 0) {
556559
EXPECT_EQ(child.child_str1(), kTestStr1);
557560
} else {
@@ -725,6 +728,70 @@ TEST(CppGeneratedCode, SetExtension) {
725728
EXPECT_EQ(::protos::internal::GetInternalMsg(*ext), prior_message);
726729
}
727730

731+
TEST(CppGeneratedCode, SetExtensionWithPtr) {
732+
::protos::Arena arena_model;
733+
::protos::Ptr<TestModel> model =
734+
::protos::CreateMessage<TestModel>(arena_model);
735+
void* prior_message;
736+
{
737+
// Use a nested scope to make sure the arenas are fused correctly.
738+
::protos::Arena arena;
739+
::protos::Ptr<ThemeExtension> extension1 =
740+
::protos::CreateMessage<ThemeExtension>(arena);
741+
extension1->set_ext_name("Hello World");
742+
prior_message = ::protos::internal::GetInternalMsg(extension1);
743+
EXPECT_EQ(false, ::protos::HasExtension(model, theme));
744+
auto res = ::protos::SetExtension(model, theme, extension1);
745+
EXPECT_EQ(true, res.ok());
746+
}
747+
EXPECT_EQ(true, ::protos::HasExtension(model, theme));
748+
auto ext = ::protos::GetExtension(model, theme);
749+
EXPECT_TRUE(ext.ok());
750+
EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message);
751+
}
752+
753+
#ifndef _MSC_VER
754+
TEST(CppGeneratedCode, SetExtensionShouldNotCompileForWrongType) {
755+
::protos::Arena arena;
756+
::protos::Ptr<TestModel> model = ::protos::CreateMessage<TestModel>(arena);
757+
ThemeExtension extension1;
758+
ContainerExtension extension2;
759+
760+
const auto canSetExtension = [&](auto l) {
761+
return Requires<decltype(model)>(l);
762+
};
763+
EXPECT_TRUE(canSetExtension(
764+
[](auto p) -> decltype(::protos::SetExtension(p, theme, extension1)) {}));
765+
// Wrong extension value type should fail to compile.
766+
EXPECT_TRUE(!canSetExtension(
767+
[](auto p) -> decltype(::protos::SetExtension(p, theme, extension2)) {}));
768+
// Wrong extension id with correct extension type should fail to compile.
769+
EXPECT_TRUE(
770+
!canSetExtension([](auto p) -> decltype(::protos::SetExtension(
771+
p, container_ext, extension1)) {}));
772+
}
773+
#endif
774+
775+
TEST(CppGeneratedCode, SetExtensionWithPtrSameArena) {
776+
::protos::Arena arena;
777+
::protos::Ptr<TestModel> model = ::protos::CreateMessage<TestModel>(arena);
778+
void* prior_message;
779+
{
780+
// Use a nested scope to make sure the arenas are fused correctly.
781+
::protos::Ptr<ThemeExtension> extension1 =
782+
::protos::CreateMessage<ThemeExtension>(arena);
783+
extension1->set_ext_name("Hello World");
784+
prior_message = ::protos::internal::GetInternalMsg(extension1);
785+
EXPECT_EQ(false, ::protos::HasExtension(model, theme));
786+
auto res = ::protos::SetExtension(model, theme, extension1);
787+
EXPECT_EQ(true, res.ok());
788+
}
789+
EXPECT_EQ(true, ::protos::HasExtension(model, theme));
790+
auto ext = ::protos::GetExtension(model, theme);
791+
EXPECT_TRUE(ext.ok());
792+
EXPECT_NE(::protos::internal::GetInternalMsg(*ext), prior_message);
793+
}
794+
728795
TEST(CppGeneratedCode, SetExtensionFusingFailureShouldCopy) {
729796
// Use an initial block to disallow fusing.
730797
char initial_block[1000];

protos_generator/tests/test_model.proto

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ import "protos_generator/tests/child_model.proto";
1414
message TestModelContainer {
1515
repeated TestModel models = 1;
1616
optional ChildModel3 proto_3_child = 2;
17+
extensions 10000 to max
18+
[verification = UNVERIFIED];
1719
}
1820

1921
message TestModel {
@@ -138,6 +140,17 @@ extend TestModel {
138140
optional ThemeExtension theme = 12001;
139141
}
140142

143+
message ContainerExtension {
144+
extend TestModelContainer {
145+
optional ContainerExtension container_extension = 12004;
146+
}
147+
optional string ext_container_name = 1;
148+
}
149+
150+
extend TestModelContainer {
151+
optional ContainerExtension container_ext = 12005;
152+
}
153+
141154
message OtherExtension {
142155
optional string ext2_name = 1;
143156
}

0 commit comments

Comments
 (0)