Skip to content

Commit e3848c1

Browse files
Adam Cozzette authored and copybara-github committed
Implement enforcement of target constraints
This is a new feature allowing fields to be annotated with a `targets` option specifying what kinds of entities that field may be applied to when used in an option.

PiperOrigin-RevId: 527990260
1 parent 37dfe80 commit e3848c1

File tree

4 files changed

+453
-36
lines changed

4 files changed

+453
-36
lines changed

src/google/protobuf/compiler/BUILD.bazel

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -11,16 +11,16 @@ load(
1111
)
1212
load("@rules_proto//proto:defs.bzl", "proto_library")
1313
load("//build_defs:arch_tests.bzl", "aarch64_test", "x86_64_test")
14-
load("//build_defs:cpp_opts.bzl", "COPTS", "LINK_OPTS")
14+
load("//build_defs:cpp_opts.bzl", "COPTS")
1515

1616
proto_library(
1717
name = "plugin_proto",
1818
srcs = ["plugin.proto"],
19+
strip_import_prefix = "/src",
1920
visibility = [
2021
"//:__pkg__",
2122
"//pkg:__pkg__",
2223
],
23-
strip_import_prefix = "/src",
2424
deps = ["//:descriptor_proto"],
2525
)
2626

@@ -93,11 +93,14 @@ cc_library(
9393
"//src/google/protobuf:descriptor_legacy",
9494
"//src/google/protobuf:protobuf_nowkt",
9595
"//src/google/protobuf/compiler/allowlists",
96+
"@com_google_absl//absl/algorithm",
97+
"@com_google_absl//absl/algorithm:container",
9698
"@com_google_absl//absl/container:btree",
9799
"@com_google_absl//absl/log:absl_check",
98100
"@com_google_absl//absl/log:absl_log",
99101
"@com_google_absl//absl/strings",
100102
"@com_google_absl//absl/strings:str_format",
103+
"@com_google_absl//absl/types:span",
101104
],
102105
)
103106

@@ -309,22 +312,21 @@ cc_library(
309312
visibility = ["//src/google/protobuf:__subpackages__"],
310313
deps = [
311314
"//src/google/protobuf:protobuf_nowkt",
312-
"@com_google_absl//absl/types:span",
313315
"@com_google_absl//absl/container:flat_hash_set",
316+
"@com_google_absl//absl/types:span",
314317
],
315318
)
316319

317-
318320
cc_test(
319321
name = "retention_unittest",
320322
srcs = ["retention_unittest.cc"],
321323
deps = [
322324
":importer",
323325
":retention",
324326
"//src/google/protobuf/io",
327+
"@com_google_absl//absl/log:die_if_null",
325328
"@com_google_googletest//:gtest",
326329
"@com_google_googletest//:gtest_main",
327-
"@com_google_absl//absl/log:die_if_null",
328330
],
329331
)
330332

src/google/protobuf/compiler/command_line_interface.cc

Lines changed: 226 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@
3434

3535
#include "google/protobuf/compiler/command_line_interface.h"
3636

37+
#include "absl/algorithm/container.h"
3738
#include "absl/container/btree_set.h"
3839
#include "absl/container/flat_hash_map.h"
40+
#include "absl/types/span.h"
3941
#include "google/protobuf/compiler/allowlists/allowlists.h"
4042
#include "google/protobuf/descriptor_legacy.h"
4143

@@ -1028,7 +1030,20 @@ struct VisitImpl {
10281030
Visitor visitor;
10291031
void Visit(const FieldDescriptor* descriptor) { visitor(descriptor); }
10301032

1031-
void Visit(const EnumDescriptor* descriptor) { visitor(descriptor); }
1033+
void Visit(const EnumValueDescriptor* descriptor) { visitor(descriptor); }
1034+
1035+
void Visit(const EnumDescriptor* descriptor) {
1036+
visitor(descriptor);
1037+
for (int i = 0; i < descriptor->value_count(); i++) {
1038+
Visit(descriptor->value(i));
1039+
}
1040+
}
1041+
1042+
void Visit(const Descriptor::ExtensionRange* descriptor) {
1043+
visitor(descriptor);
1044+
}
1045+
1046+
void Visit(const OneofDescriptor* descriptor) { visitor(descriptor); }
10321047

10331048
void Visit(const Descriptor* descriptor) {
10341049
visitor(descriptor);
@@ -1048,10 +1063,27 @@ struct VisitImpl {
10481063
for (int i = 0; i < descriptor->extension_count(); i++) {
10491064
Visit(descriptor->extension(i));
10501065
}
1066+
1067+
for (int i = 0; i < descriptor->extension_range_count(); i++) {
1068+
Visit(descriptor->extension_range(i));
1069+
}
1070+
1071+
for (int i = 0; i < descriptor->oneof_decl_count(); i++) {
1072+
Visit(descriptor->oneof_decl(i));
1073+
}
10511074
}
10521075

1053-
void Visit(const std::vector<const FileDescriptor*>& descriptors) {
1054-
for (auto* descriptor : descriptors) {
1076+
void Visit(const MethodDescriptor* method) { visitor(method); }
1077+
1078+
void Visit(const ServiceDescriptor* descriptor) {
1079+
visitor(descriptor);
1080+
for (int i = 0; i < descriptor->method_count(); i++) {
1081+
Visit(descriptor->method(i));
1082+
}
1083+
}
1084+
1085+
void Visit(absl::Span<const FileDescriptor*> descriptors) {
1086+
for (const FileDescriptor* descriptor : descriptors) {
10551087
visitor(descriptor);
10561088
for (int i = 0; i < descriptor->message_type_count(); i++) {
10571089
Visit(descriptor->message_type(i));
@@ -1062,17 +1094,18 @@ struct VisitImpl {
10621094
for (int i = 0; i < descriptor->extension_count(); i++) {
10631095
Visit(descriptor->extension(i));
10641096
}
1097+
for (int i = 0; i < descriptor->service_count(); i++) {
1098+
Visit(descriptor->service(i));
1099+
}
10651100
}
10661101
}
10671102
};
10681103

10691104
// Visit every node in the descriptors calling `visitor(node)`.
10701105
// The visitor does not need to handle all possible node types. Types that are
10711106
// not visitable via `visitor` will be ignored.
1072-
// Disclaimer: this is not fully implemented yet to visit _every_ node.
1073-
// VisitImpl might need to be updated where needs arise.
10741107
template <typename Visitor>
1075-
void VisitDescriptors(const std::vector<const FileDescriptor*>& descriptors,
1108+
void VisitDescriptors(absl::Span<const FileDescriptor*> descriptors,
10761109
Visitor visitor) {
10771110
// Provide a fallback to ignore all the nodes that are not interesting to the
10781111
// input visitor.
@@ -1099,8 +1132,151 @@ bool HasReservedFieldNumber(const FieldDescriptor* field) {
10991132
namespace {
11001133
std::unique_ptr<SimpleDescriptorDatabase>
11011134
PopulateSingleSimpleDescriptorDatabase(const std::string& descriptor_set_name);
1135+
1136+
// Indicates whether the field is compatible with the given target type.
1137+
bool IsFieldCompatible(const FieldDescriptor& field,
1138+
FieldOptions::OptionTargetType target_type) {
1139+
const RepeatedField<int>& allowed_targets = field.options().targets();
1140+
return allowed_targets.empty() ||
1141+
absl::c_linear_search(allowed_targets, target_type);
1142+
}
1143+
1144+
// Converts the OptionTargetType enum to a string suitable for use in error
1145+
// messages.
1146+
absl::string_view TargetTypeString(FieldOptions::OptionTargetType target_type) {
1147+
switch (target_type) {
1148+
case FieldOptions::TARGET_TYPE_FILE:
1149+
return "file";
1150+
case FieldOptions::TARGET_TYPE_EXTENSION_RANGE:
1151+
return "extension range";
1152+
case FieldOptions::TARGET_TYPE_MESSAGE:
1153+
return "message";
1154+
case FieldOptions::TARGET_TYPE_FIELD:
1155+
return "field";
1156+
case FieldOptions::TARGET_TYPE_ONEOF:
1157+
return "oneof";
1158+
case FieldOptions::TARGET_TYPE_ENUM:
1159+
return "enum";
1160+
case FieldOptions::TARGET_TYPE_ENUM_ENTRY:
1161+
return "enum entry";
1162+
case FieldOptions::TARGET_TYPE_SERVICE:
1163+
return "service";
1164+
case FieldOptions::TARGET_TYPE_METHOD:
1165+
return "method";
1166+
default:
1167+
return "unknown";
1168+
}
1169+
}
1170+
1171+
// Recursively validates that the options message (or subpiece of an options
1172+
// message) is compatible with the given target type.
1173+
bool ValidateTargetConstraintsRecursive(
1174+
const Message& m, DescriptorPool::ErrorCollector& error_collector,
1175+
absl::string_view file_name, FieldOptions::OptionTargetType target_type) {
1176+
std::vector<const FieldDescriptor*> fields;
1177+
const Reflection* reflection = m.GetReflection();
1178+
reflection->ListFields(m, &fields);
1179+
bool success = true;
1180+
for (const auto* field : fields) {
1181+
if (!IsFieldCompatible(*field, target_type)) {
1182+
success = false;
1183+
error_collector.RecordError(
1184+
file_name, "", nullptr, DescriptorPool::ErrorCollector::OPTION_NAME,
1185+
absl::StrCat("Option ", field->full_name(),
1186+
" cannot be set on an entity of type `",
1187+
TargetTypeString(target_type), "`."));
1188+
}
1189+
if (field->type() == FieldDescriptor::TYPE_MESSAGE) {
1190+
if (field->is_repeated()) {
1191+
int field_size = reflection->FieldSize(m, field);
1192+
for (int i = 0; i < field_size; ++i) {
1193+
if (!ValidateTargetConstraintsRecursive(
1194+
reflection->GetRepeatedMessage(m, field, i), error_collector,
1195+
file_name, target_type)) {
1196+
success = false;
1197+
}
1198+
}
1199+
} else if (!ValidateTargetConstraintsRecursive(
1200+
reflection->GetMessage(m, field), error_collector,
1201+
file_name, target_type)) {
1202+
success = false;
1203+
}
1204+
}
1205+
}
1206+
return success;
1207+
}
1208+
1209+
// Validates that the options message is correct with respect to target
1210+
// constraints, returning true if successful. This function converts the
1211+
// options message to a DynamicMessage so that we have visibility into custom
1212+
// options. `file_name` is the name of the file being validated; it is passed
1213+
// through to the error collector when a violation is reported.
1214+
bool ValidateTargetConstraints(const Message& options,
1215+
const DescriptorPool& pool,
1216+
DescriptorPool::ErrorCollector& error_collector,
1217+
absl::string_view file_name,
1218+
FieldOptions::OptionTargetType target_type) {
1219+
const Descriptor* descriptor =
1220+
pool.FindMessageTypeByName(options.GetTypeName());
1221+
if (descriptor == nullptr) {
1222+
// We were unable to find the options message in the descriptor pool. This
1223+
// implies that the proto files we are working with do not depend on
1224+
// descriptor.proto, in which case there are no custom options to worry
1225+
// about. We can therefore skip the use of DynamicMessage.
1226+
return ValidateTargetConstraintsRecursive(options, error_collector,
1227+
file_name, target_type);
1228+
} else {
1229+
DynamicMessageFactory factory;
1230+
std::unique_ptr<Message> dynamic_message(
1231+
factory.GetPrototype(descriptor)->New());
1232+
std::string serialized;
1233+
ABSL_CHECK(options.SerializeToString(&serialized));
1234+
ABSL_CHECK(dynamic_message->ParseFromString(serialized));
1235+
return ValidateTargetConstraintsRecursive(*dynamic_message, error_collector,
1236+
file_name, target_type);
1237+
}
1238+
}
1239+
1240+
// The overloaded GetTargetType() functions below allow us to map from a
1241+
// descriptor type to the associated OptionTargetType enum.
1242+
FieldOptions::OptionTargetType GetTargetType(const FileDescriptor*) {
1243+
return FieldOptions::TARGET_TYPE_FILE;
1244+
}
1245+
1246+
FieldOptions::OptionTargetType GetTargetType(
1247+
const Descriptor::ExtensionRange*) {
1248+
return FieldOptions::TARGET_TYPE_EXTENSION_RANGE;
1249+
}
1250+
1251+
FieldOptions::OptionTargetType GetTargetType(const Descriptor*) {
1252+
return FieldOptions::TARGET_TYPE_MESSAGE;
1253+
}
1254+
1255+
FieldOptions::OptionTargetType GetTargetType(const FieldDescriptor*) {
1256+
return FieldOptions::TARGET_TYPE_FIELD;
1257+
}
1258+
1259+
FieldOptions::OptionTargetType GetTargetType(const OneofDescriptor*) {
1260+
return FieldOptions::TARGET_TYPE_ONEOF;
1261+
}
1262+
1263+
FieldOptions::OptionTargetType GetTargetType(const EnumDescriptor*) {
1264+
return FieldOptions::TARGET_TYPE_ENUM;
11021265
}
11031266

1267+
FieldOptions::OptionTargetType GetTargetType(const EnumValueDescriptor*) {
1268+
return FieldOptions::TARGET_TYPE_ENUM_ENTRY;
1269+
}
1270+
1271+
FieldOptions::OptionTargetType GetTargetType(const ServiceDescriptor*) {
1272+
return FieldOptions::TARGET_TYPE_SERVICE;
1273+
}
1274+
1275+
FieldOptions::OptionTargetType GetTargetType(const MethodDescriptor*) {
1276+
return FieldOptions::TARGET_TYPE_METHOD;
1277+
}
1278+
} // namespace
1279+
11041280
int CommandLineInterface::Run(int argc, const char* const argv[]) {
11051281
Clear();
11061282

@@ -1189,31 +1365,50 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
11891365

11901366
bool validation_error = false; // Defer exiting so we log more warnings.
11911367

1192-
VisitDescriptors(parsed_files, [&](const FieldDescriptor* field) {
1193-
if (HasReservedFieldNumber(field)) {
1194-
const char* error_link = nullptr;
1195-
validation_error = true;
1196-
std::string error;
1197-
if (field->number() >= FieldDescriptor::kFirstReservedNumber &&
1198-
field->number() <= FieldDescriptor::kLastReservedNumber) {
1199-
error = absl::Substitute(
1200-
"Field numbers $0 through $1 are reserved "
1201-
"for the protocol buffer library implementation.",
1202-
FieldDescriptor::kFirstReservedNumber,
1203-
FieldDescriptor::kLastReservedNumber);
1204-
} else {
1205-
error = absl::Substitute(
1206-
"Field number $0 is reserved for specific purposes.",
1207-
field->number());
1208-
}
1209-
if (error_link) {
1210-
absl::StrAppend(&error, "(See ", error_link, ")");
1211-
}
1212-
static_cast<DescriptorPool::ErrorCollector*>(error_collector.get())
1213-
->RecordError(field->file()->name(), field->full_name(), nullptr,
1214-
DescriptorPool::ErrorCollector::NUMBER, error);
1215-
}
1216-
});
1368+
VisitDescriptors(
1369+
absl::Span<const FileDescriptor*>(parsed_files.data(),
1370+
parsed_files.size()),
1371+
[&](const FieldDescriptor* field) {
1372+
if (HasReservedFieldNumber(field)) {
1373+
const char* error_link = nullptr;
1374+
validation_error = true;
1375+
std::string error;
1376+
if (field->number() >= FieldDescriptor::kFirstReservedNumber &&
1377+
field->number() <= FieldDescriptor::kLastReservedNumber) {
1378+
error = absl::Substitute(
1379+
"Field numbers $0 through $1 are reserved "
1380+
"for the protocol buffer library implementation.",
1381+
FieldDescriptor::kFirstReservedNumber,
1382+
FieldDescriptor::kLastReservedNumber);
1383+
} else {
1384+
error = absl::Substitute(
1385+
"Field number $0 is reserved for specific purposes.",
1386+
field->number());
1387+
}
1388+
if (error_link) {
1389+
absl::StrAppend(&error, "(See ", error_link, ")");
1390+
}
1391+
static_cast<DescriptorPool::ErrorCollector*>(error_collector.get())
1392+
->RecordError(field->file()->name(), field->full_name(), nullptr,
1393+
DescriptorPool::ErrorCollector::NUMBER, error);
1394+
}
1395+
});
1396+
1397+
// We visit one file at a time because we need to provide the file name for
1398+
// error messages. Usually we can get the file name from any descriptor with
1399+
// something like descriptor->file()->name(), but ExtensionRange does not
1400+
// support this.
1401+
for (const google::protobuf::FileDescriptor* file : parsed_files) {
1402+
VisitDescriptors(
1403+
absl::Span<const FileDescriptor*>(&file, 1),
1404+
[&](const auto* descriptor) {
1405+
if (!ValidateTargetConstraints(
1406+
descriptor->options(), *descriptor_pool, *error_collector,
1407+
file->name(), GetTargetType(descriptor))) {
1408+
validation_error = true;
1409+
}
1410+
});
1411+
}
12171412

12181413

12191414
if (validation_error) {

0 commit comments

Comments
 (0)