Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 61 additions & 13 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,14 @@ impl<'b> CodeGenerator<'_, 'b> {
let boxed = self
.context
.should_box_message_field(fq_message_name, &field.descriptor);
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let custom_module_path = self
.context
.get_custom_scalar_module_path(&field.descriptor, fq_message_name);
let ty = self.resolve_type(
&field.descriptor,
fq_message_name,
custom_module_path.as_deref(),
);

debug!(
" field: {:?}, type: {:?}, boxed: {}",
Expand All @@ -440,10 +447,10 @@ impl<'b> CodeGenerator<'_, 'b> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(&field.descriptor);
let type_tag = self.field_type_tag(&field.descriptor, custom_module_path.as_deref());
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
if type_ == Type::Bytes && custom_module_path.is_none() {
let bytes_type = self
.context
.bytes_type(fq_message_name, field.descriptor.name());
Expand Down Expand Up @@ -545,8 +552,14 @@ impl<'b> CodeGenerator<'_, 'b> {
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let key_ty = self.resolve_type(key, fq_message_name);
let value_ty = self.resolve_type(value, fq_message_name);
let field_name = format!("{}.{}", fq_message_name, field.descriptor.name());
let custom_module_path_key = self.context.get_custom_scalar_module_path(key, &field_name);
let custom_module_path_value = self
.context
.get_custom_scalar_module_path(value, &field_name);
let key_ty = self.resolve_type(key, fq_message_name, custom_module_path_key.as_deref());
let value_ty =
self.resolve_type(value, fq_message_name, custom_module_path_value.as_deref());

debug!(
" map field: {:?}, key type: {:?}, value type: {:?}",
Expand All @@ -561,8 +574,8 @@ impl<'b> CodeGenerator<'_, 'b> {
let map_type = self
.context
.map_type(fq_message_name, field.descriptor.name());
let key_tag = self.field_type_tag(key);
let value_tag = self.map_value_type_tag(value);
let key_tag = self.field_type_tag(key, custom_module_path_key.as_deref());
let value_tag = self.map_value_type_tag(value, custom_module_path_value.as_deref());

self.buf.push_str(&format!(
"#[prost({} = \"{}, {}\", tag = \"{}\")]\n",
Expand Down Expand Up @@ -659,7 +672,10 @@ impl<'b> CodeGenerator<'_, 'b> {
}

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
let custom_module_path = self
.context
.get_custom_scalar_module_path(&field.descriptor, fq_message_name);
let ty_tag = self.field_type_tag(&field.descriptor, custom_module_path.as_deref());
self.buf.push_str(&format!(
"#[prost({}, tag = \"{}\")]\n",
ty_tag,
Expand All @@ -668,7 +684,11 @@ impl<'b> CodeGenerator<'_, 'b> {
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let ty = self.resolve_type(
&field.descriptor,
fq_message_name,
custom_module_path.as_deref(),
);

let boxed = self.context.should_box_oneof_field(
fq_message_name,
Expand Down Expand Up @@ -973,7 +993,20 @@ impl<'b> CodeGenerator<'_, 'b> {
self.buf.push_str("}\n");
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
fn resolve_type(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
custom_scalar_module_path: Option<&str>,
) -> String {
if let Some(module_path) = custom_scalar_module_path {
return format!(
"<{} as {}::encoding::CustomScalarInterface>::Type",
module_path,
self.context.prost_path()
);
}

match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Expand Down Expand Up @@ -1030,7 +1063,14 @@ impl<'b> CodeGenerator<'_, 'b> {
.join("::")
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn field_type_tag(
&self,
field: &FieldDescriptorProto,
custom_scalar_module_path: Option<&str>,
) -> Cow<'static, str> {
if let Some(module_path) = custom_scalar_module_path {
return Cow::Owned(format!("custom_scalar({})", module_path));
}
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Expand All @@ -1056,13 +1096,21 @@ impl<'b> CodeGenerator<'_, 'b> {
}
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn map_value_type_tag(
&self,
field: &FieldDescriptorProto,
custom_scalar_module_path: Option<&str>,
) -> Cow<'static, str> {
if let Some(module_path) = custom_scalar_module_path {
return Cow::Owned(format!("custom_scalar({})", module_path));
}
match field.r#type() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
)),
_ => self.field_type_tag(field),

_ => self.field_type_tag(field, custom_scalar_module_path),
}
}

Expand Down
23 changes: 23 additions & 0 deletions prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use log::debug;
use log::trace;

use prost::Message;
use prost_types::field_descriptor_proto::Type;
use prost_types::{FileDescriptorProto, FileDescriptorSet};

use crate::code_generator::CodeGenerator;
Expand All @@ -36,6 +37,7 @@ pub struct Config {
pub(crate) message_attributes: PathMap<String>,
pub(crate) enum_attributes: PathMap<String>,
pub(crate) field_attributes: PathMap<String>,
pub(crate) custom_scalar: PathMap<(Type, String)>,
pub(crate) boxed: PathMap<()>,
pub(crate) prost_types: bool,
pub(crate) strip_enum_prefix: bool,
Expand Down Expand Up @@ -375,6 +377,26 @@ impl Config {
self
}

pub fn custom_scalar<M, I, S>(
&mut self,
proto_type: Type,
module_path: M,
paths: I,
) -> &mut Self
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
M: AsRef<str>,
{
for matcher in paths {
self.custom_scalar.insert(
matcher.as_ref().to_string(),
(proto_type, module_path.as_ref().to_string()),
);
}
self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1202,6 +1224,7 @@ impl default::Default for Config {
message_attributes: PathMap::default(),
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
custom_scalar: PathMap::default(),
boxed: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
Expand Down
20 changes: 20 additions & 0 deletions prost-build/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,26 @@ impl<'a> Context<'a> {
false
}

pub fn get_custom_scalar_module_path(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
) -> Option<String> {
if matches!(field.r#type(), Type::Message | Type::Group) {
return None;
}
self.config
.custom_scalar
.get_first_field(fq_message_name, field.name())
.and_then(|(ty, interface)| {
if field.r#type() == *ty {
Some(interface.clone())
} else {
None
}
})
}

/// Returns `true` if this message can automatically derive Copy trait.
pub fn can_message_derive_copy(&self, fq_message_name: &str) -> bool {
assert_eq!(".", &fq_message_name[..1]);
Expand Down
35 changes: 18 additions & 17 deletions prost-derive/src/field/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ impl Field {
/// Returns a statement which encodes the map field.
pub fn encode(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
let tag = self.tag;
let key_mod = self.key_ty.module();
let ke = quote!(#prost_path::encoding::#key_mod::encode);
let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
let key_mod = self.key_ty.encoding_module(prost_path);
let ke = quote!(#key_mod::encode);
let kl = quote!(#key_mod::encoded_len);
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
Expand All @@ -147,9 +147,9 @@ impl Field {
}
}
ValueTy::Scalar(value_ty) => {
let val_mod = value_ty.module();
let ve = quote!(#prost_path::encoding::#val_mod::encode);
let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
let val_mod = value_ty.encoding_module(prost_path);
let ve = quote!(#val_mod::encode);
let vl = quote!(#val_mod::encoded_len);
quote! {
#prost_path::encoding::#module::encode(
#ke,
Expand Down Expand Up @@ -179,8 +179,8 @@ impl Field {
/// Returns an expression which evaluates to the result of merging a decoded key value pair
/// into the map.
pub fn merge(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
let key_mod = self.key_ty.module();
let km = quote!(#prost_path::encoding::#key_mod::merge);
let key_mod = self.key_ty.encoding_module(prost_path);
let km = quote!(#key_mod::merge);
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
Expand All @@ -197,9 +197,9 @@ impl Field {
}
}
ValueTy::Scalar(value_ty) => {
let val_mod = value_ty.module();
let vm = quote!(#prost_path::encoding::#val_mod::merge);
quote!(#prost_path::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
let val_mod = value_ty.encoding_module(prost_path);
let vm = quote!(#val_mod::merge);
quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
}
ValueTy::Message => quote! {
#prost_path::encoding::#module::merge(
Expand All @@ -216,8 +216,8 @@ impl Field {
/// Returns an expression which evaluates to the encoded length of the map.
pub fn encoded_len(&self, prost_path: &Path, ident: TokenStream) -> TokenStream {
let tag = self.tag;
let key_mod = self.key_ty.module();
let kl = quote!(#prost_path::encoding::#key_mod::encoded_len);
let key_mod = self.key_ty.encoding_module(prost_path);
let kl = quote!(#key_mod::encoded_len);
let module = self.map_ty.module();
match &self.value_ty {
ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
Expand All @@ -233,8 +233,8 @@ impl Field {
}
}
ValueTy::Scalar(value_ty) => {
let val_mod = value_ty.module();
let vl = quote!(#prost_path::encoding::#val_mod::encoded_len);
let val_mod = value_ty.encoding_module(prost_path);
let vl = quote!(#val_mod::encoded_len);
quote!(#prost_path::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
}
ValueTy::Message => quote! {
Expand All @@ -256,7 +256,7 @@ impl Field {
pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option<TokenStream> {
if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
let key_ty = self.key_ty.rust_type(prost_path);
let key_ref_ty = self.key_ty.rust_ref_type();
let key_ref_ty = self.key_ty.rust_ref_type(prost_path);

let get = Ident::new(&format!("get_{ident}"), Span::call_site());
let insert = Ident::new(&format!("insert_{ident}"), Span::call_site());
Expand Down Expand Up @@ -366,7 +366,8 @@ fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
| scalar::Ty::Sfixed32
| scalar::Ty::Sfixed64
| scalar::Ty::Bool
| scalar::Ty::String => Ok(ty),
| scalar::Ty::String
| scalar::Ty::CustomScalar(_) => Ok(ty),
_ => bail!("invalid map key type: {s}"),
}
}
Expand Down
2 changes: 1 addition & 1 deletion prost-derive/src/field/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl Field {

pub fn methods(&self, prost_path: &Path, ident: &TokenStream) -> Option<TokenStream> {
match *self {
Field::Scalar(ref scalar) => scalar.methods(ident),
Field::Scalar(ref scalar) => scalar.methods(prost_path, ident),
Field::Map(ref map) => map.methods(prost_path, ident),
_ => None,
}
Expand Down
Loading
Loading