diff --git a/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs b/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs index d6016db19..f0e3effb5 100644 --- a/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs +++ b/src/hyperlight_common/src/flatbuffer_wrappers/function_types.rs @@ -99,7 +99,7 @@ pub enum ReturnValue { /// bool Bool(bool), /// () - Void, + Void(()), /// Vec VecBytes(Vec), } @@ -508,7 +508,7 @@ impl TryFrom for () { #[cfg_attr(feature = "tracing", instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace"))] fn try_from(value: ReturnValue) -> Result { match value { - ReturnValue::Void => Ok(()), + ReturnValue::Void(()) => Ok(()), _ => { bail!("Unexpected return value type: {:?}", value) } @@ -570,7 +570,7 @@ impl TryFrom> for ReturnValue { }; Ok(ReturnValue::String(hlstring.unwrap_or("".to_string()))) } - FbReturnValue::hlvoid => Ok(ReturnValue::Void), + FbReturnValue::hlvoid => Ok(ReturnValue::Void(())), FbReturnValue::hlsizeprefixedbuffer => { let hlvecbytes = match function_call_result_fb.return_value_as_hlsizeprefixedbuffer() { @@ -724,7 +724,7 @@ impl TryFrom<&ReturnValue> for Vec { builder.finish_size_prefixed(function_call_result, None); builder.finished_data().to_vec() } - ReturnValue::Void => { + ReturnValue::Void(()) => { let hlvoid = hlvoid::create(&mut builder, &hlvoidArgs {}); let function_call_result = FbFunctionCallResult::create( &mut builder, diff --git a/src/hyperlight_host/src/func/host_functions.rs b/src/hyperlight_host/src/func/host_functions.rs index 5e9e2b213..8656cd4b0 100644 --- a/src/hyperlight_host/src/func/host_functions.rs +++ b/src/hyperlight_host/src/func/host_functions.rs @@ -169,7 +169,7 @@ macro_rules! impl_host_function { let cloned = self_.clone(); let func = Box::new(move |args: Vec| { let ($($P,)*) = match <[ParameterValue; N]>::try_from(args) { - Ok([$($P,)*]) => ($($P::get_inner($P)?,)*), + Ok([$($P,)*]) => ($($P::from_value($P)?,)*), Err(args) => { log_then_return!(UnexpectedNoOfArguments(args.len(), N)); } }; @@ -178,10 +178,10 @@ macro_rules! impl_host_function { .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))?( $($P),* )?; - Ok(result.get_hyperlight_value()) + Ok(result.into_value()) }); - let parameter_types = Some(vec![$($P::get_hyperlight_type()),*]); + let parameter_types = Some(vec![$($P::TYPE),*]); if let Some(_eas) = extra_allowed_syscalls { if cfg!(all(feature = "seccomp", target_os = "linux")) { @@ -196,7 +196,7 @@ macro_rules! impl_host_function { &HostFunctionDefinition::new( name.to_string(), parameter_types, - R::get_hyperlight_type(), + R::TYPE, ), HyperlightFunction::new(func), _eas, @@ -216,7 +216,7 @@ macro_rules! impl_host_function { &HostFunctionDefinition::new( name.to_string(), parameter_types, - R::get_hyperlight_type(), + R::TYPE, ), HyperlightFunction::new(func), )?; diff --git a/src/hyperlight_host/src/func/param_type.rs b/src/hyperlight_host/src/func/param_type.rs index 2b1e9413b..4282304e8 100644 --- a/src/hyperlight_host/src/func/param_type.rs +++ b/src/hyperlight_host/src/func/param_type.rs @@ -26,167 +26,53 @@ use crate::{log_then_return, Result}; /// For each parameter type Hyperlight supports in host functions, we /// provide an implementation for `SupportedParameterType` pub trait SupportedParameterType: Sized { - /// Get the underlying Hyperlight parameter type representing this - /// `SupportedParameterType` - fn get_hyperlight_type() -> ParameterType; + /// The underlying Hyperlight parameter type representing this `SupportedParameterType` + const TYPE: ParameterType; + /// Get the underling Hyperlight parameter value representing this /// `SupportedParameterType` - fn get_hyperlight_value(&self) -> ParameterValue; + fn into_value(self) -> ParameterValue; /// Get the actual inner value of this `SupportedParameterType` - fn get_inner(a: ParameterValue) -> Result; + fn from_value(value: ParameterValue) -> Result; } // We can then implement these traits for each type that Hyperlight supports as a parameter or return type -impl SupportedParameterType for String { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::String - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::String(self.clone()) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::String(i) => Ok(i), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "String")); - } - } - } -} - -impl SupportedParameterType for i32 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::Int - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::Int(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::Int(i) => Ok(i), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "i32")); - } - } - } -} - -impl SupportedParameterType for u32 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::UInt - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::UInt(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::UInt(ui) => Ok(ui), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "u32")); - } - } - } +macro_rules! for_each_param_type { + ($macro:ident) => { + $macro!(String, String); + $macro!(i32, Int); + $macro!(u32, UInt); + $macro!(i64, Long); + $macro!(u64, ULong); + $macro!(bool, Bool); + $macro!(Vec, VecBytes); + }; } -impl SupportedParameterType for i64 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::Long - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::Long(*self) - } +macro_rules! impl_supported_param_type { + ($type:ty, $enum:ident) => { + impl SupportedParameterType for $type { + const TYPE: ParameterType = ParameterType::$enum; - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::Long(l) => Ok(l), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "i64")); + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + fn into_value(self) -> ParameterValue { + ParameterValue::$enum(self) } - } - } -} - -impl SupportedParameterType for u64 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::ULong - } - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::ULong(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::ULong(ul) => Ok(ul), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "u64")); + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + fn from_value(value: ParameterValue) -> Result { + match value { + ParameterValue::$enum(i) => Ok(i), + other => { + log_then_return!(ParameterValueConversionFailure( + other.clone(), + stringify!($type) + )); + } + } } } - } + }; } -impl SupportedParameterType for bool { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::Bool - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::Bool(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result { - match a { - ParameterValue::Bool(i) => Ok(i), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "bool")); - } - } - } -} - -impl SupportedParameterType for Vec { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ParameterType { - ParameterType::VecBytes - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ParameterValue { - ParameterValue::VecBytes(self.clone()) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ParameterValue) -> Result> { - match a { - ParameterValue::VecBytes(i) => Ok(i), - other => { - log_then_return!(ParameterValueConversionFailure(other.clone(), "Vec")); - } - } - } -} +for_each_param_type!(impl_supported_param_type); diff --git a/src/hyperlight_host/src/func/ret_type.rs b/src/hyperlight_host/src/func/ret_type.rs index 0b18572c8..756113ba3 100644 --- a/src/hyperlight_host/src/func/ret_type.rs +++ b/src/hyperlight_host/src/func/ret_type.rs @@ -22,188 +22,53 @@ use crate::{log_then_return, Result}; /// This is a marker trait that is used to indicate that a type is a valid Hyperlight return type. pub trait SupportedReturnType: Sized { - /// Gets the return type of the supported return value - fn get_hyperlight_type() -> ReturnType; + /// The return type of the supported return value + const TYPE: ReturnType; /// Gets the value of the supported return value - fn get_hyperlight_value(&self) -> ReturnValue; + fn into_value(self) -> ReturnValue; /// Gets the inner value of the supported return type - fn get_inner(a: ReturnValue) -> Result; + fn from_value(value: ReturnValue) -> Result; } -impl SupportedReturnType for () { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::Void - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::Void - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result<()> { - match a { - ReturnValue::Void => Ok(()), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "()")); - } - } - } -} - -impl SupportedReturnType for String { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::String - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::String(self.clone()) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::String(i) => Ok(i), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "String")); - } - } - } -} - -impl SupportedReturnType for i32 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::Int - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::Int(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::Int(i) => Ok(i), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "i32")); - } - } - } -} - -impl SupportedReturnType for u32 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::UInt - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::UInt(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::UInt(u) => Ok(u), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "u32")); - } - } - } +macro_rules! for_each_return_type { + ($macro:ident) => { + $macro!((), Void); + $macro!(String, String); + $macro!(i32, Int); + $macro!(u32, UInt); + $macro!(i64, Long); + $macro!(u64, ULong); + $macro!(bool, Bool); + $macro!(Vec, VecBytes); + }; } -impl SupportedReturnType for i64 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::Long - } +macro_rules! impl_supported_return_type { + ($type:ty, $enum:ident) => { + impl SupportedReturnType for $type { + const TYPE: ReturnType = ReturnType::$enum; - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::Long(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::Long(l) => Ok(l), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "i64")); + #[instrument(skip_all, parent = Span::current(), level= "Trace")] + fn into_value(self) -> ReturnValue { + ReturnValue::$enum(self) } - } - } -} - -impl SupportedReturnType for u64 { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::ULong - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::ULong(*self) - } - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::ULong(ul) => Ok(ul), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "u64")); + #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] + fn from_value(value: ReturnValue) -> Result { + match value { + ReturnValue::$enum(i) => Ok(i), + other => { + log_then_return!(ReturnValueConversionFailure( + other.clone(), + stringify!($type) + )); + } + } } } - } + }; } -impl SupportedReturnType for bool { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::Bool - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::Bool(*self) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result { - match a { - ReturnValue::Bool(i) => Ok(i), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "bool")); - } - } - } -} - -impl SupportedReturnType for Vec { - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_type() -> ReturnType { - ReturnType::VecBytes - } - - #[instrument(skip_all, parent = Span::current(), level= "Trace")] - fn get_hyperlight_value(&self) -> ReturnValue { - ReturnValue::VecBytes(self.clone()) - } - - #[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")] - fn get_inner(a: ReturnValue) -> Result> { - match a { - ReturnValue::VecBytes(i) => Ok(i), - other => { - log_then_return!(ReturnValueConversionFailure(other.clone(), "Vec")); - } - } - } -} +for_each_return_type!(impl_supported_return_type);