From 7561f6fd4090513a2dfcaeb533d802d85c8aabbb Mon Sep 17 00:00:00 2001 From: Jorge Prendes Date: Wed, 7 May 2025 20:09:41 +0100 Subject: [PATCH] Improve the ergonomics of registering host functions Signed-off-by: Jorge Prendes --- README.md | 11 +-- src/hyperlight_host/benches/benchmarks.rs | 10 +-- .../examples/guest-debugging/main.rs | 10 +-- .../examples/hello-world/main.rs | 10 +-- .../src/func/guest_dispatch.rs | 23 ++---- .../src/func/host_functions.rs | 45 ++++++++++++ .../src/sandbox/uninitialized.rs | 72 +++++++++++-------- .../tests/sandbox_host_tests.rs | 14 ++-- 8 files changed, 109 insertions(+), 86 deletions(-) diff --git a/README.md b/README.md index d1d2a2f65..f1e7c715e 100644 --- a/README.md +++ b/README.md @@ -47,16 +47,11 @@ fn main() -> hyperlight_host::Result<()> { None, // default host print function )?; - // Register a host function - fn sleep_5_secs() -> hyperlight_host::Result<()> { + // Registering a host function makes it available to be called by the guest + uninitialized_sandbox.register("Sleep5Secs", || { thread::sleep(std::time::Duration::from_secs(5)); Ok(()) - } - - let host_function = Arc::new(Mutex::new(sleep_5_secs)); - - // Registering a host function makes it available to be called by the guest - host_function.register(&mut uninitialized_sandbox, "Sleep5Secs")?; + })?; // Note: This function is unused by the guest code below, it's just here for demonstration purposes // Initialize sandbox to be able to call host functions diff --git a/src/hyperlight_host/benches/benchmarks.rs b/src/hyperlight_host/benches/benchmarks.rs index a9aacbff4..d680d484f 100644 --- a/src/hyperlight_host/benches/benchmarks.rs +++ b/src/hyperlight_host/benches/benchmarks.rs @@ -14,12 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::sync::{Arc, Mutex}; use std::time::Duration; use criterion::{criterion_group, criterion_main, Criterion}; use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; -use hyperlight_host::func::HostFunction; use hyperlight_host::sandbox::{MultiUseSandbox, SandboxConfiguration, UninitializedSandbox}; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -110,12 +108,8 @@ fn guest_call_benchmark(c: &mut Criterion) { let mut uninitialized_sandbox = create_uninit_sandbox(); // Define a host function that adds two integers and register it. - fn add(a: i32, b: i32) -> hyperlight_host::Result { - Ok(a + b) - } - let host_function = Arc::new(Mutex::new(add)); - host_function - .register(&mut uninitialized_sandbox, "HostAdd") + uninitialized_sandbox + .register("HostAdd", |a: i32, b: i32| Ok(a + b)) .unwrap(); let multiuse_sandbox: MultiUseSandbox = diff --git a/src/hyperlight_host/examples/guest-debugging/main.rs b/src/hyperlight_host/examples/guest-debugging/main.rs index 3aff4e7a5..d476c7fa5 100644 --- a/src/hyperlight_host/examples/guest-debugging/main.rs +++ b/src/hyperlight_host/examples/guest-debugging/main.rs @@ -14,11 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::sync::{Arc, Mutex}; use std::thread; use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; -use hyperlight_host::func::HostFunction; #[cfg(gdb)] use hyperlight_host::sandbox::config::DebugInfo; use hyperlight_host::sandbox::SandboxConfiguration; @@ -55,14 +53,10 @@ fn main() -> hyperlight_host::Result<()> { )?; // Register a host functions - fn sleep_5_secs() -> hyperlight_host::Result<()> { + uninitialized_sandbox.register("Sleep5Secs", || { thread::sleep(std::time::Duration::from_secs(5)); Ok(()) - } - - let host_function = Arc::new(Mutex::new(sleep_5_secs)); - - host_function.register(&mut uninitialized_sandbox, "Sleep5Secs")?; + })?; // Note: This function is unused, it's just here for demonstration purposes // Initialize sandbox to be able to call host functions diff --git a/src/hyperlight_host/examples/hello-world/main.rs b/src/hyperlight_host/examples/hello-world/main.rs index 2d49c5253..ad1dedbd2 100644 --- a/src/hyperlight_host/examples/hello-world/main.rs +++ b/src/hyperlight_host/examples/hello-world/main.rs @@ -14,11 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. */ -use std::sync::{Arc, Mutex}; use std::thread; use hyperlight_common::flatbuffer_wrappers::function_types::{ParameterValue, ReturnType}; -use hyperlight_host::func::HostFunction; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; use hyperlight_host::{MultiUseSandbox, UninitializedSandbox}; @@ -35,14 +33,10 @@ fn main() -> hyperlight_host::Result<()> { )?; // Register a host functions - fn sleep_5_secs() -> hyperlight_host::Result<()> { + uninitialized_sandbox.register("Sleep5Secs", || { thread::sleep(std::time::Duration::from_secs(5)); Ok(()) - } - - let host_function = Arc::new(Mutex::new(sleep_5_secs)); - - host_function.register(&mut uninitialized_sandbox, "Sleep5Secs")?; + })?; // Note: This function is unused, it's just here for demonstration purposes // Initialize sandbox to be able to call host functions diff --git a/src/hyperlight_host/src/func/guest_dispatch.rs b/src/hyperlight_host/src/func/guest_dispatch.rs index d2395ef17..a7849bc99 100644 --- a/src/hyperlight_host/src/func/guest_dispatch.rs +++ b/src/hyperlight_host/src/func/guest_dispatch.rs @@ -112,7 +112,6 @@ mod tests { use super::*; use crate::func::call_ctx::MultiUseGuestCallContext; - use crate::func::host_functions::HostFunction; use crate::sandbox::is_hypervisor_present; use crate::sandbox::uninitialized::GuestBinary; use crate::sandbox_state::sandbox::EvolvableSandbox; @@ -152,8 +151,6 @@ mod tests { // First, run to make sure it fails. { - let make_get_pid_syscall_func = Arc::new(Mutex::new(make_get_pid_syscall)); - let mut usbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().expect("Guest Binary Missing")), None, @@ -162,7 +159,7 @@ mod tests { ) .unwrap(); - make_get_pid_syscall_func.register(&mut usbox, "MakeGetpidSyscall")?; + usbox.register("MakeGetpidSyscall", make_get_pid_syscall)?; let mut sbox: MultiUseSandbox = usbox.evolve(Noop::default())?; @@ -188,8 +185,6 @@ mod tests { // Second, run with allowing `SYS_getpid` #[cfg(feature = "seccomp")] { - let make_get_pid_syscall_func = Arc::new(Mutex::new(make_get_pid_syscall)); - let mut usbox = UninitializedSandbox::new( GuestBinary::FilePath(simple_guest_as_string().expect("Guest Binary Missing")), None, @@ -198,9 +193,9 @@ mod tests { ) .unwrap(); - make_get_pid_syscall_func.register_with_extra_allowed_syscalls( - &mut usbox, + usbox.register_with_extra_allowed_syscalls( "MakeGetpidSyscall", + make_get_pid_syscall, vec![libc::SYS_getpid], )?; // ^^^ note, we are allowing SYS_getpid @@ -453,18 +448,12 @@ mod tests { Ok(()) } - let host_spin_func = Arc::new(Mutex::new(spin)); - #[cfg(any(target_os = "windows", not(feature = "seccomp")))] - host_spin_func.register(&mut usbox, "Spin").unwrap(); + usbox.register("Spin", spin).unwrap(); #[cfg(all(target_os = "linux", feature = "seccomp"))] - host_spin_func - .register_with_extra_allowed_syscalls( - &mut usbox, - "Spin", - vec![libc::SYS_clock_nanosleep], - ) + usbox + .register_with_extra_allowed_syscalls("Spin", spin, vec![libc::SYS_clock_nanosleep]) .unwrap(); let sandbox: MultiUseSandbox = usbox.evolve(Noop::default()).unwrap(); diff --git a/src/hyperlight_host/src/func/host_functions.rs b/src/hyperlight_host/src/func/host_functions.rs index eb2768aa4..39a1d137b 100644 --- a/src/hyperlight_host/src/func/host_functions.rs +++ b/src/hyperlight_host/src/func/host_functions.rs @@ -68,6 +68,15 @@ pub trait HostFunction { ) -> Result<()>; } +/// Tait for types that can be converted into types implementing `HostFunction`. +pub trait IntoHostFunction { + /// Concrete type of the returned host function + type Output: HostFunction; + + /// Convert the type into a host function + fn into_host_function(self) -> Self::Output; +} + macro_rules! impl_host_function { (@count) => { 0 }; (@count $P:ident $(, $R:ident)*) => { @@ -109,6 +118,42 @@ macro_rules! impl_host_function { } } + impl IntoHostFunction for F + where + F: FnMut($($P),*) -> Result + Send + 'static, + Arc>: HostFunction, + { + type Output = Arc>; + + fn into_host_function(self) -> Self::Output { + Arc::new(Mutex::new(self)) + } + } + + impl IntoHostFunction for Arc> + where + F: FnMut($($P),*) -> Result + Send + 'static, + Arc>: HostFunction, + { + type Output = Arc>; + + fn into_host_function(self) -> Self::Output { + self + } + } + + impl IntoHostFunction for &Arc> + where + F: FnMut($($P),*) -> Result + Send + 'static, + Arc>: HostFunction, + { + type Output = Arc>; + + fn into_host_function(self) -> Self::Output { + self.clone() + } + } + fn register_host_function( self_: Arc>, sandbox: &mut UninitializedSandbox, diff --git a/src/hyperlight_host/src/sandbox/uninitialized.rs b/src/hyperlight_host/src/sandbox/uninitialized.rs index ed0d08982..fbf2f7b1d 100644 --- a/src/hyperlight_host/src/sandbox/uninitialized.rs +++ b/src/hyperlight_host/src/sandbox/uninitialized.rs @@ -30,7 +30,7 @@ use super::mem_mgr::MemMgrWrapper; use super::run_options::SandboxRunOptions; use super::uninitialized_evolve::evolve_impl_multi_use; use crate::error::HyperlightError::GuestBinaryShouldBeAFile; -use crate::func::host_functions::HostFunction; +use crate::func::host_functions::{HostFunction, IntoHostFunction}; use crate::mem::exe::ExeInfo; use crate::mem::mgr::{SandboxMemoryManager, STACK_COOKIE_LEN}; use crate::mem::shared_mem::ExclusiveSharedMemory; @@ -224,24 +224,15 @@ impl UninitializedSandbox { // If we were passed a writer for host print register it otherwise use the default. match host_print_writer { Some(writer_func) => { - #[allow(clippy::arc_with_non_send_sync)] - let writer_func = Arc::new(Mutex::new(writer_func)); - #[cfg(any(target_os = "windows", not(feature = "seccomp")))] - writer_func - .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .register(&mut sandbox, "HostPrint")?; + writer_func.register(&mut sandbox, "HostPrint")?; #[cfg(all(target_os = "linux", feature = "seccomp"))] - writer_func - .try_lock() - .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? - .register_with_extra_allowed_syscalls( - &mut sandbox, - "HostPrint", - extra_allowed_syscalls_for_writer_func, - )?; + writer_func.register_with_extra_allowed_syscalls( + &mut sandbox, + "HostPrint", + extra_allowed_syscalls_for_writer_func, + )?; } None => { let default_writer = Arc::new(Mutex::new(default_writer_func)); @@ -309,6 +300,31 @@ impl UninitializedSandbox { pub fn set_max_guest_log_level(&mut self, log_level: LevelFilter) { self.max_guest_log_level = Some(log_level); } + + /// Register a host function with the given name in the sandbox. + pub fn register(&mut self, name: impl AsRef, host_func: F) -> Result<()> + where + F: IntoHostFunction, + { + host_func.into_host_function().register(self, name.as_ref()) + } + + /// Register the host function with the given name in the sandbox, allowing extra syscalls. + #[cfg(all(feature = "seccomp", target_os = "linux"))] + pub fn register_with_extra_allowed_syscalls( + &mut self, + name: impl AsRef, + host_func: F, + extra_allowed_syscalls: impl IntoIterator, + ) -> Result<()> + where + F: IntoHostFunction, + { + let extra_allowed_syscalls: Vec<_> = extra_allowed_syscalls.into_iter().collect(); + host_func + .into_host_function() + .register_with_extra_allowed_syscalls(self, name.as_ref(), extra_allowed_syscalls) + } } // Check to see if the current version of Windows is supported // Hyperlight is only supported on Windows 11 and Windows Server 2022 and later @@ -355,7 +371,6 @@ mod tests { use tracing_core::Subscriber; use uuid::Uuid; - use crate::func::HostFunction; use crate::sandbox::uninitialized::GuestBinary; use crate::sandbox::SandboxConfiguration; use crate::sandbox_state::sandbox::EvolvableSandbox; @@ -516,9 +531,8 @@ mod tests { // simple register + call { let mut usbox = uninitialized_sandbox(); - let test0 = |arg: i32| -> Result { Ok(arg + 1) }; - let test_func0 = Arc::new(Mutex::new(test0)); - test_func0.register(&mut usbox, "test0").unwrap(); + + usbox.register("test0", |arg: i32| Ok(arg + 1)).unwrap(); let sandbox: Result = usbox.evolve(Noop::default()); assert!(sandbox.is_ok()); @@ -542,9 +556,8 @@ mod tests { // multiple parameters register + call { let mut usbox = uninitialized_sandbox(); - let test1 = |arg1: i32, arg2: i32| -> Result { Ok(arg1 + arg2) }; - let test_func1 = Arc::new(Mutex::new(test1)); - test_func1.register(&mut usbox, "test1").unwrap(); + + usbox.register("test1", |a: i32, b: i32| Ok(a + b)).unwrap(); let sandbox: Result = usbox.evolve(Noop::default()); assert!(sandbox.is_ok()); @@ -571,12 +584,13 @@ mod tests { // incorrect arguments register + call { let mut usbox = uninitialized_sandbox(); - let test2 = |arg1: String| -> Result<()> { - println!("test2 called: {}", arg1); - Ok(()) - }; - let test_func2 = Arc::new(Mutex::new(test2)); - test_func2.register(&mut usbox, "test2").unwrap(); + + usbox + .register("test2", |msg: String| { + println!("test2 called: {}", msg); + Ok(()) + }) + .unwrap(); let sandbox: Result = usbox.evolve(Noop::default()); assert!(sandbox.is_ok()); diff --git a/src/hyperlight_host/tests/sandbox_host_tests.rs b/src/hyperlight_host/tests/sandbox_host_tests.rs index 64173eefd..e563c872e 100644 --- a/src/hyperlight_host/tests/sandbox_host_tests.rs +++ b/src/hyperlight_host/tests/sandbox_host_tests.rs @@ -18,7 +18,7 @@ use core::f64; use std::sync::{Arc, Mutex}; use common::new_uninit; -use hyperlight_host::func::{HostFunction, ParameterValue, ReturnType, ReturnValue}; +use hyperlight_host::func::{ParameterValue, ReturnType, ReturnValue}; use hyperlight_host::sandbox::SandboxConfiguration; use hyperlight_host::sandbox_state::sandbox::EvolvableSandbox; use hyperlight_host::sandbox_state::transition::Noop; @@ -545,16 +545,15 @@ fn callback_test_helper() -> Result<()> { // create host function let vec = Arc::new(Mutex::new(vec![])); let vec_cloned = vec.clone(); - let host_func1 = Arc::new(Mutex::new(move |msg: String| { + + sandbox.register("HostMethod1", move |msg: String| { let len = msg.len(); vec_cloned .try_lock() .map_err(|e| new_error!("Error locking at {}:{}: {}", file!(), line!(), e))? .push(msg); Ok(len as i32) - })); - - host_func1.register(&mut sandbox, "HostMethod1").unwrap(); + })?; // call guest function that calls host function let mut init_sandbox: MultiUseSandbox = sandbox.evolve(Noop::default())?; @@ -611,10 +610,9 @@ fn host_function_error() -> Result<()> { // when a host function returns an error, an infinite loop is created. for mut sandbox in get_callbackguest_uninit_sandboxes(None).into_iter().take(1) { // create host function - let host_func1 = Arc::new(Mutex::new(|_msg: String| -> Result { + sandbox.register("HostMethod1", |_: String| -> Result { Err(new_error!("Host function error!")) - })); - host_func1.register(&mut sandbox, "HostMethod1").unwrap(); + })?; // call guest function that calls host function let mut init_sandbox: MultiUseSandbox = sandbox.evolve(Noop::default())?;