Skip to content

Commit e06d40a

Browse files
committed
guest: make call_host_function generic to avoid two steps to retrieve return value
Signed-off-by: Doru Blânzeanu <[email protected]>
1 parent d353fc7 commit e06d40a

File tree

7 files changed

+55
-36
lines changed

7 files changed

+55
-36
lines changed

src/hyperlight_guest/src/host_function_call.rs

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ use crate::shared_input_data::try_pop_shared_input_data_into;
3232
use crate::shared_output_data::push_shared_output_data;
3333

3434
/// Get a return value from a host function call.
35-
/// This usually requires a host function to be called first using `call_host_function`.
35+
/// This usually requires a host function to be called first using `call_host_function_internal`.
36+
/// When calling `call_host_function<T>`, this function is called internally to get the return
37+
/// value.
3638
pub fn get_host_return_value<T: TryFrom<ReturnValue>>() -> Result<T> {
3739
let return_value = try_pop_shared_input_data_into::<ReturnValue>()
3840
.expect("Unable to deserialize a return value from host");
@@ -47,10 +49,12 @@ pub fn get_host_return_value<T: TryFrom<ReturnValue>>() -> Result<T> {
4749
})
4850
}
4951

50-
// TODO: Make this generic, return a Result<T, ErrorCode> this should allow callers to call this function and get the result type they expect
51-
// without having to do the conversion themselves
52-
53-
pub fn call_host_function(
52+
/// Internal function to call a host function without generic type parameters.
53+
/// This is used by both the Rust and C APIs to reduce code duplication.
54+
///
55+
/// This function doesn't return the host function result directly, instead it just
56+
/// performs the call. The result must be obtained by calling `get_host_return_value`.
57+
pub fn call_host_function_internal(
5458
function_name: &str,
5559
parameters: Option<Vec<ParameterValue>>,
5660
return_type: ReturnType,
@@ -73,6 +77,20 @@ pub fn call_host_function(
7377
Ok(())
7478
}
7579

80+
/// Call a host function with the given parameters and return type.
81+
/// This function serializes the function call and its parameters,
82+
/// sends it to the host, and then retrieves the return value.
83+
///
84+
/// The return value is deserialized into the specified type `T`.
85+
pub fn call_host_function<T: TryFrom<ReturnValue>>(
86+
function_name: &str,
87+
parameters: Option<Vec<ParameterValue>>,
88+
return_type: ReturnType,
89+
) -> Result<T> {
90+
call_host_function_internal(function_name, parameters, return_type)?;
91+
get_host_return_value::<T>()
92+
}
93+
7694
pub fn outb(port: u16, data: &[u8]) {
7795
unsafe {
7896
let mut i = 0;
@@ -109,13 +127,13 @@ pub fn debug_print(msg: &str) {
109127
/// existence of the input and output memory regions.
110128
pub fn print_output_with_host_print(function_call: &FunctionCall) -> Result<Vec<u8>> {
111129
if let ParameterValue::String(message) = function_call.parameters.clone().unwrap()[0].clone() {
112-
call_host_function(
130+
let res = call_host_function::<i32>(
113131
"HostPrint",
114132
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
115133
ReturnType::Int,
116134
)?;
117-
let res_i = get_host_return_value::<i32>()?;
118-
Ok(get_flatbuffer_result(res_i))
135+
136+
Ok(get_flatbuffer_result(res))
119137
} else {
120138
Err(HyperlightGuestError::new(
121139
ErrorCode::GuestError,

src/hyperlight_guest/src/print.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ pub unsafe extern "C" fn _putchar(c: c_char) {
5555
.expect("Failed to convert buffer to string")
5656
};
5757

58-
call_host_function(
58+
// HostPrint returns an i32, but we don't care about the return value
59+
let _ = call_host_function::<i32>(
5960
"HostPrint",
6061
Some(Vec::from(&[ParameterValue::String(str)])),
61-
ReturnType::Void,
62+
ReturnType::Int,
6263
)
6364
.expect("Failed to call HostPrint");
6465

src/hyperlight_guest_capi/src/dispatch.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
1010
use hyperlight_guest::error::{HyperlightGuestError, Result};
1111
use hyperlight_guest::guest_function_definition::GuestFunctionDefinition;
1212
use hyperlight_guest::guest_function_register::GuestFunctionRegister;
13-
use hyperlight_guest::host_function_call::call_host_function;
13+
use hyperlight_guest::host_function_call::call_host_function_internal;
1414

1515
use crate::types::{FfiFunctionCall, FfiVec};
1616
static mut REGISTERED_C_GUEST_FUNCTIONS: GuestFunctionRegister = GuestFunctionRegister::new();
@@ -89,5 +89,9 @@ pub extern "C" fn hl_call_host_function(function_call: &FfiFunctionCall) {
8989
let parameters = unsafe { function_call.copy_parameters() };
9090
let func_name = unsafe { function_call.copy_function_name() };
9191
let return_type = unsafe { function_call.copy_return_type() };
92-
let _ = call_host_function(&func_name, Some(parameters), return_type);
92+
93+
// Use the non-generic internal implementation
94+
// The C API will then call specific getter functions to fetch the properly typed return value
95+
let _ = call_host_function_internal(&func_name, Some(parameters), return_type)
96+
.expect("Failed to call host function");
9397
}

src/hyperlight_host/src/func/utils.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
///
55
/// Usage:
66
/// ```rust
7+
/// use hyperlight_host::func::for_each_tuple;
8+
///
79
/// macro_rules! my_macro {
810
/// ([$count:expr] ($($name:ident: $type:ident),*)) => {
911
/// // $count is the arity of the tuple

src/hyperlight_host/src/sandbox/initialized_multi_use.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ impl MultiUseSandbox {
112112
/// let u_sbox = UninitializedSandbox::new(
113113
/// GuestBinary::FilePath("some_guest_binary".to_string()),
114114
/// None,
115-
/// None,
116115
/// ).unwrap();
117116
/// let sbox: MultiUseSandbox = u_sbox.evolve(Noop::default()).unwrap();
118117
/// // Next, create a new call context from the single-use sandbox.

src/tests/rust_guests/callbackguest/src/main.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ use hyperlight_common::flatbuffer_wrappers::util::get_flatbuffer_result;
3434
use hyperlight_guest::error::{HyperlightGuestError, Result};
3535
use hyperlight_guest::guest_function_definition::GuestFunctionDefinition;
3636
use hyperlight_guest::guest_function_register::register_function;
37-
use hyperlight_guest::host_function_call::{
38-
call_host_function, get_host_return_value, print_output_with_host_print,
39-
};
37+
use hyperlight_guest::host_function_call::{call_host_function, print_output_with_host_print};
4038
use hyperlight_guest::logging::log_message;
4139

4240
fn send_message_to_host_method(
@@ -45,15 +43,13 @@ fn send_message_to_host_method(
4543
message: &str,
4644
) -> Result<Vec<u8>> {
4745
let message = format!("{}{}", guest_message, message);
48-
call_host_function(
46+
let res = call_host_function::<i32>(
4947
method_name,
5048
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
5149
ReturnType::Int,
5250
)?;
5351

54-
let result = get_host_return_value::<i32>()?;
55-
56-
Ok(get_flatbuffer_result(result))
52+
Ok(get_flatbuffer_result(res))
5753
}
5854

5955
fn guest_function(function_call: &FunctionCall) -> Result<Vec<u8>> {
@@ -101,7 +97,7 @@ fn guest_function3(function_call: &FunctionCall) -> Result<Vec<u8>> {
10197
}
10298

10399
fn guest_function4(_: &FunctionCall) -> Result<Vec<u8>> {
104-
call_host_function(
100+
call_host_function::<()>(
105101
"HostMethod4",
106102
Some(Vec::from(&[ParameterValue::String(
107103
"Hello from GuestFunction4".to_string(),
@@ -157,7 +153,7 @@ fn call_error_method(function_call: &FunctionCall) -> Result<Vec<u8>> {
157153
}
158154

159155
fn call_host_spin(_: &FunctionCall) -> Result<Vec<u8>> {
160-
call_host_function("Spin", None, ReturnType::Void)?;
156+
call_host_function::<()>("Spin", None, ReturnType::Void)?;
161157
Ok(get_flatbuffer_result(()))
162158
}
163159

src/tests/rust_guests/simpleguest/src/main.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ use hyperlight_guest::entrypoint::{abort_with_code, abort_with_code_and_message}
4545
use hyperlight_guest::error::{HyperlightGuestError, Result};
4646
use hyperlight_guest::guest_function_definition::GuestFunctionDefinition;
4747
use hyperlight_guest::guest_function_register::register_function;
48-
use hyperlight_guest::host_function_call::{call_host_function, get_host_return_value};
48+
use hyperlight_guest::host_function_call::{call_host_function, call_host_function_internal};
4949
use hyperlight_guest::memory::malloc;
5050
use hyperlight_guest::{logging, MIN_STACK_ADDRESS};
5151
use log::{error, LevelFilter};
@@ -86,13 +86,13 @@ fn echo_float(function_call: &FunctionCall) -> Result<Vec<u8>> {
8686
}
8787

8888
fn print_output(message: &str) -> Result<Vec<u8>> {
89-
call_host_function(
89+
let res = call_host_function::<i32>(
9090
"HostPrint",
9191
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
9292
ReturnType::Int,
9393
)?;
94-
let result = get_host_return_value::<i32>()?;
95-
Ok(get_flatbuffer_result(result))
94+
95+
Ok(get_flatbuffer_result(res))
9696
}
9797

9898
fn simple_print_output(function_call: &FunctionCall) -> Result<Vec<u8>> {
@@ -679,9 +679,7 @@ fn add_to_static_and_fail(_: &FunctionCall) -> Result<Vec<u8>> {
679679

680680
fn violate_seccomp_filters(function_call: &FunctionCall) -> Result<Vec<u8>> {
681681
if function_call.parameters.is_none() {
682-
call_host_function("MakeGetpidSyscall", None, ReturnType::ULong)?;
683-
684-
let res = get_host_return_value::<u64>()?;
682+
let res = call_host_function::<u64>("MakeGetpidSyscall", None, ReturnType::ULong)?;
685683

686684
Ok(get_flatbuffer_result(res))
687685
} else {
@@ -697,14 +695,11 @@ fn add(function_call: &FunctionCall) -> Result<Vec<u8>> {
697695
function_call.parameters.clone().unwrap()[0].clone(),
698696
function_call.parameters.clone().unwrap()[1].clone(),
699697
) {
700-
call_host_function(
698+
let res = call_host_function::<i32>(
701699
"HostAdd",
702700
Some(Vec::from(&[ParameterValue::Int(a), ParameterValue::Int(b)])),
703701
ReturnType::Int,
704702
)?;
705-
706-
let res = get_host_return_value::<i32>()?;
707-
708703
Ok(get_flatbuffer_result(res))
709704
} else {
710705
Err(HyperlightGuestError::new(
@@ -1156,12 +1151,11 @@ pub fn guest_dispatch_function(function_call: FunctionCall) -> Result<Vec<u8>> {
11561151
1,
11571152
);
11581153

1159-
call_host_function(
1154+
let result = call_host_function::<i32>(
11601155
"HostPrint",
11611156
Some(Vec::from(&[ParameterValue::String(message.to_string())])),
11621157
ReturnType::Int,
11631158
)?;
1164-
let result = get_host_return_value::<i32>()?;
11651159
let function_name = function_call.function_name.clone();
11661160
let param_len = function_call.parameters.clone().unwrap_or_default().len();
11671161
let call_type = function_call.function_call_type().clone();
@@ -1195,7 +1189,12 @@ fn fuzz_host_function(func: FunctionCall) -> Result<Vec<u8>> {
11951189
))
11961190
}
11971191
};
1198-
call_host_function(&host_func_name, Some(params), func.expected_return_type)
1192+
1193+
// Because we do not know at compile time the actual return type of the host function to be called
1194+
// we cannot use the `call_host_function<T>` generic function.
1195+
// We need to use the `call_host_function_internal` function that does not retrieve the return
1196+
// value
1197+
call_host_function_internal(&host_func_name, Some(params), func.expected_return_type)
11991198
.expect("failed to call host function");
12001199
Ok(get_flatbuffer_result(()))
12011200
}

0 commit comments

Comments
 (0)