3434
3535#include " google/protobuf/compiler/command_line_interface.h"
3636
37+ #include " absl/algorithm/container.h"
3738#include " absl/container/btree_set.h"
3839#include " absl/container/flat_hash_map.h"
40+ #include " absl/types/span.h"
3941#include " google/protobuf/compiler/allowlists/allowlists.h"
4042#include " google/protobuf/descriptor_legacy.h"
4143
@@ -1028,7 +1030,20 @@ struct VisitImpl {
10281030 Visitor visitor;
10291031 void Visit (const FieldDescriptor* descriptor) { visitor (descriptor); }
10301032
1031- void Visit (const EnumDescriptor* descriptor) { visitor (descriptor); }
1033+ void Visit (const EnumValueDescriptor* descriptor) { visitor (descriptor); }
1034+
1035+ void Visit (const EnumDescriptor* descriptor) {
1036+ visitor (descriptor);
1037+ for (int i = 0 ; i < descriptor->value_count (); i++) {
1038+ Visit (descriptor->value (i));
1039+ }
1040+ }
1041+
1042+ void Visit (const Descriptor::ExtensionRange* descriptor) {
1043+ visitor (descriptor);
1044+ }
1045+
1046+ void Visit (const OneofDescriptor* descriptor) { visitor (descriptor); }
10321047
10331048 void Visit (const Descriptor* descriptor) {
10341049 visitor (descriptor);
@@ -1048,10 +1063,27 @@ struct VisitImpl {
10481063 for (int i = 0 ; i < descriptor->extension_count (); i++) {
10491064 Visit (descriptor->extension (i));
10501065 }
1066+
1067+ for (int i = 0 ; i < descriptor->extension_range_count (); i++) {
1068+ Visit (descriptor->extension_range (i));
1069+ }
1070+
1071+ for (int i = 0 ; i < descriptor->oneof_decl_count (); i++) {
1072+ Visit (descriptor->oneof_decl (i));
1073+ }
10511074 }
10521075
1053- void Visit (const std::vector<const FileDescriptor*>& descriptors) {
1054- for (auto * descriptor : descriptors) {
1076+ void Visit (const MethodDescriptor* method) { visitor (method); }
1077+
1078+ void Visit (const ServiceDescriptor* descriptor) {
1079+ visitor (descriptor);
1080+ for (int i = 0 ; i < descriptor->method_count (); i++) {
1081+ Visit (descriptor->method (i));
1082+ }
1083+ }
1084+
1085+ void Visit (absl::Span<const FileDescriptor*> descriptors) {
1086+ for (const FileDescriptor* descriptor : descriptors) {
10551087 visitor (descriptor);
10561088 for (int i = 0 ; i < descriptor->message_type_count (); i++) {
10571089 Visit (descriptor->message_type (i));
@@ -1062,17 +1094,18 @@ struct VisitImpl {
10621094 for (int i = 0 ; i < descriptor->extension_count (); i++) {
10631095 Visit (descriptor->extension (i));
10641096 }
1097+ for (int i = 0 ; i < descriptor->service_count (); i++) {
1098+ Visit (descriptor->service (i));
1099+ }
10651100 }
10661101 }
10671102};
10681103
10691104// Visit every node in the descriptors calling `visitor(node)`.
10701105// The visitor does not need to handle all possible node types. Types that are
10711106// not visitable via `visitor` will be ignored.
1072- // Disclaimer: this is not fully implemented yet to visit _every_ node.
1073- // VisitImpl might need to be updated where needs arise.
10741107template <typename Visitor>
1075- void VisitDescriptors (const std::vector <const FileDescriptor*>& descriptors,
1108+ void VisitDescriptors (absl::Span <const FileDescriptor*> descriptors,
10761109 Visitor visitor) {
10771110 // Provide a fallback to ignore all the nodes that are not interesting to the
10781111 // input visitor.
@@ -1099,8 +1132,151 @@ bool HasReservedFieldNumber(const FieldDescriptor* field) {
10991132namespace {
11001133std::unique_ptr<SimpleDescriptorDatabase>
11011134PopulateSingleSimpleDescriptorDatabase (const std::string& descriptor_set_name);
1135+
1136+ // Indicates whether the field is compatible with the given target type.
1137+ bool IsFieldCompatible (const FieldDescriptor& field,
1138+ FieldOptions::OptionTargetType target_type) {
1139+ const RepeatedField<int >& allowed_targets = field.options ().targets ();
1140+ return allowed_targets.empty () ||
1141+ absl::c_linear_search (allowed_targets, target_type);
1142+ }
1143+
1144+ // Converts the OptionTargetType enum to a string suitable for use in error
1145+ // messages.
1146+ absl::string_view TargetTypeString (FieldOptions::OptionTargetType target_type) {
1147+ switch (target_type) {
1148+ case FieldOptions::TARGET_TYPE_FILE:
1149+ return " file" ;
1150+ case FieldOptions::TARGET_TYPE_EXTENSION_RANGE:
1151+ return " extension range" ;
1152+ case FieldOptions::TARGET_TYPE_MESSAGE:
1153+ return " message" ;
1154+ case FieldOptions::TARGET_TYPE_FIELD:
1155+ return " field" ;
1156+ case FieldOptions::TARGET_TYPE_ONEOF:
1157+ return " oneof" ;
1158+ case FieldOptions::TARGET_TYPE_ENUM:
1159+ return " enum" ;
1160+ case FieldOptions::TARGET_TYPE_ENUM_ENTRY:
1161+ return " enum entry" ;
1162+ case FieldOptions::TARGET_TYPE_SERVICE:
1163+ return " service" ;
1164+ case FieldOptions::TARGET_TYPE_METHOD:
1165+ return " method" ;
1166+ default :
1167+ return " unknown" ;
1168+ }
1169+ }
1170+
1171+ // Recursively validates that the options message (or subpiece of an options
1172+ // message) is compatible with the given target type.
1173+ bool ValidateTargetConstraintsRecursive (
1174+ const Message& m, DescriptorPool::ErrorCollector& error_collector,
1175+ absl::string_view file_name, FieldOptions::OptionTargetType target_type) {
1176+ std::vector<const FieldDescriptor*> fields;
1177+ const Reflection* reflection = m.GetReflection ();
1178+ reflection->ListFields (m, &fields);
1179+ bool success = true ;
1180+ for (const auto * field : fields) {
1181+ if (!IsFieldCompatible (*field, target_type)) {
1182+ success = false ;
1183+ error_collector.RecordError (
1184+ file_name, " " , nullptr , DescriptorPool::ErrorCollector::OPTION_NAME,
1185+ absl::StrCat (" Option " , field->full_name (),
1186+ " cannot be set on an entity of type `" ,
1187+ TargetTypeString (target_type), " `." ));
1188+ }
1189+ if (field->type () == FieldDescriptor::TYPE_MESSAGE) {
1190+ if (field->is_repeated ()) {
1191+ int field_size = reflection->FieldSize (m, field);
1192+ for (int i = 0 ; i < field_size; ++i) {
1193+ if (!ValidateTargetConstraintsRecursive (
1194+ reflection->GetRepeatedMessage (m, field, i), error_collector,
1195+ file_name, target_type)) {
1196+ success = false ;
1197+ }
1198+ }
1199+ } else if (!ValidateTargetConstraintsRecursive (
1200+ reflection->GetMessage (m, field), error_collector,
1201+ file_name, target_type)) {
1202+ success = false ;
1203+ }
1204+ }
1205+ }
1206+ return success;
1207+ }
1208+
1209+ // Validates that the options message is correct with respect to target
1210+ // constraints, returning true if successful. This function converts the
1211+ // options message to a DynamicMessage so that we have visibility into custom
1212+ // options. We take the element name as a FunctionRef so that we do not have to
1213+ // pay the cost of constructing it unless there is an error.
1214+ bool ValidateTargetConstraints (const Message& options,
1215+ const DescriptorPool& pool,
1216+ DescriptorPool::ErrorCollector& error_collector,
1217+ absl::string_view file_name,
1218+ FieldOptions::OptionTargetType target_type) {
1219+ const Descriptor* descriptor =
1220+ pool.FindMessageTypeByName (options.GetTypeName ());
1221+ if (descriptor == nullptr ) {
1222+ // We were unable to find the options message in the descriptor pool. This
1223+ // implies that the proto files we are working with do not depend on
1224+ // descriptor.proto, in which case there are no custom options to worry
1225+ // about. We can therefore skip the use of DynamicMessage.
1226+ return ValidateTargetConstraintsRecursive (options, error_collector,
1227+ file_name, target_type);
1228+ } else {
1229+ DynamicMessageFactory factory;
1230+ std::unique_ptr<Message> dynamic_message (
1231+ factory.GetPrototype (descriptor)->New ());
1232+ std::string serialized;
1233+ ABSL_CHECK (options.SerializeToString (&serialized));
1234+ ABSL_CHECK (dynamic_message->ParseFromString (serialized));
1235+ return ValidateTargetConstraintsRecursive (*dynamic_message, error_collector,
1236+ file_name, target_type);
1237+ }
1238+ }
1239+
1240+ // The overloaded GetTargetType() functions below allow us to map from a
1241+ // descriptor type to the associated OptionTargetType enum.
1242+ FieldOptions::OptionTargetType GetTargetType (const FileDescriptor*) {
1243+ return FieldOptions::TARGET_TYPE_FILE;
1244+ }
1245+
1246+ FieldOptions::OptionTargetType GetTargetType (
1247+ const Descriptor::ExtensionRange*) {
1248+ return FieldOptions::TARGET_TYPE_EXTENSION_RANGE;
1249+ }
1250+
1251+ FieldOptions::OptionTargetType GetTargetType (const Descriptor*) {
1252+ return FieldOptions::TARGET_TYPE_MESSAGE;
1253+ }
1254+
1255+ FieldOptions::OptionTargetType GetTargetType (const FieldDescriptor*) {
1256+ return FieldOptions::TARGET_TYPE_FIELD;
1257+ }
1258+
1259+ FieldOptions::OptionTargetType GetTargetType (const OneofDescriptor*) {
1260+ return FieldOptions::TARGET_TYPE_ONEOF;
1261+ }
1262+
1263+ FieldOptions::OptionTargetType GetTargetType (const EnumDescriptor*) {
1264+ return FieldOptions::TARGET_TYPE_ENUM;
11021265}
11031266
1267+ FieldOptions::OptionTargetType GetTargetType (const EnumValueDescriptor*) {
1268+ return FieldOptions::TARGET_TYPE_ENUM_ENTRY;
1269+ }
1270+
1271+ FieldOptions::OptionTargetType GetTargetType (const ServiceDescriptor*) {
1272+ return FieldOptions::TARGET_TYPE_SERVICE;
1273+ }
1274+
1275+ FieldOptions::OptionTargetType GetTargetType (const MethodDescriptor*) {
1276+ return FieldOptions::TARGET_TYPE_METHOD;
1277+ }
1278+ } // namespace
1279+
11041280int CommandLineInterface::Run (int argc, const char * const argv[]) {
11051281 Clear ();
11061282
@@ -1189,31 +1365,50 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
11891365
11901366 bool validation_error = false ; // Defer exiting so we log more warnings.
11911367
1192- VisitDescriptors (parsed_files, [&](const FieldDescriptor* field) {
1193- if (HasReservedFieldNumber (field)) {
1194- const char * error_link = nullptr ;
1195- validation_error = true ;
1196- std::string error;
1197- if (field->number () >= FieldDescriptor::kFirstReservedNumber &&
1198- field->number () <= FieldDescriptor::kLastReservedNumber ) {
1199- error = absl::Substitute (
1200- " Field numbers $0 through $1 are reserved "
1201- " for the protocol buffer library implementation." ,
1202- FieldDescriptor::kFirstReservedNumber ,
1203- FieldDescriptor::kLastReservedNumber );
1204- } else {
1205- error = absl::Substitute (
1206- " Field number $0 is reserved for specific purposes." ,
1207- field->number ());
1208- }
1209- if (error_link) {
1210- absl::StrAppend (&error, " (See " , error_link, " )" );
1211- }
1212- static_cast <DescriptorPool::ErrorCollector*>(error_collector.get ())
1213- ->RecordError (field->file ()->name (), field->full_name (), nullptr ,
1214- DescriptorPool::ErrorCollector::NUMBER, error);
1215- }
1216- });
1368+ VisitDescriptors (
1369+ absl::Span<const FileDescriptor*>(parsed_files.data (),
1370+ parsed_files.size ()),
1371+ [&](const FieldDescriptor* field) {
1372+ if (HasReservedFieldNumber (field)) {
1373+ const char * error_link = nullptr ;
1374+ validation_error = true ;
1375+ std::string error;
1376+ if (field->number () >= FieldDescriptor::kFirstReservedNumber &&
1377+ field->number () <= FieldDescriptor::kLastReservedNumber ) {
1378+ error = absl::Substitute (
1379+ " Field numbers $0 through $1 are reserved "
1380+ " for the protocol buffer library implementation." ,
1381+ FieldDescriptor::kFirstReservedNumber ,
1382+ FieldDescriptor::kLastReservedNumber );
1383+ } else {
1384+ error = absl::Substitute (
1385+ " Field number $0 is reserved for specific purposes." ,
1386+ field->number ());
1387+ }
1388+ if (error_link) {
1389+ absl::StrAppend (&error, " (See " , error_link, " )" );
1390+ }
1391+ static_cast <DescriptorPool::ErrorCollector*>(error_collector.get ())
1392+ ->RecordError (field->file ()->name (), field->full_name (), nullptr ,
1393+ DescriptorPool::ErrorCollector::NUMBER, error);
1394+ }
1395+ });
1396+
1397+ // We visit one file at a time because we need to provide the file name for
1398+ // error messages. Usually we can get the file name from any descriptor with
1399+ // something like descriptor->file()->name(), but ExtensionRange does not
1400+ // support this.
1401+ for (const google::protobuf::FileDescriptor* file : parsed_files) {
1402+ VisitDescriptors (
1403+ absl::Span<const FileDescriptor*>(&file, 1 ),
1404+ [&](const auto * descriptor) {
1405+ if (!ValidateTargetConstraints (
1406+ descriptor->options (), *descriptor_pool, *error_collector,
1407+ file->name (), GetTargetType (descriptor))) {
1408+ validation_error = true ;
1409+ }
1410+ });
1411+ }
12171412
12181413
12191414 if (validation_error) {
0 commit comments