Skip to content

Refactor HostFuncsWrapper #480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 95 additions & 93 deletions src/hyperlight_host/src/sandbox/host_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,25 @@ use crate::{new_error, Result};

#[derive(Default, Clone)]
/// A Wrapper around details of functions exposed by the Host
pub struct HostFuncsWrapper {
functions_map: HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
pub struct FunctionRegistry {
functions_map: HashMap<String, FunctionEntry>,
}

impl HostFuncsWrapper {
#[derive(Clone)]
pub struct FunctionEntry {
pub function: HyperlightFunction,
pub extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
}

impl FunctionRegistry {
/// Register a host function with the sandbox.
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
pub(crate) fn register_host_function(
&mut self,
name: String,
func: HyperlightFunction,
) -> Result<()> {
register_host_function_helper(self, name, func, None)
self.register_host_function_helper(name, func, None)
}

/// Register a host function with the sandbox, with a list of extra syscalls
Expand All @@ -53,7 +59,7 @@ impl HostFuncsWrapper {
func: HyperlightFunction,
extra_allowed_syscalls: Vec<ExtraAllowedSyscall>,
) -> Result<()> {
register_host_function_helper(self, name, func, Some(extra_allowed_syscalls))
self.register_host_function_helper(name, func, Some(extra_allowed_syscalls))
}

/// Assuming a host function called `"HostPrint"` exists, and takes a
Expand All @@ -63,11 +69,7 @@ impl HostFuncsWrapper {
/// and `Err` otherwise.
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
pub(super) fn host_print(&mut self, msg: String) -> Result<i32> {
let res = call_host_func_impl(
&self.functions_map,
"HostPrint",
vec![ParameterValue::String(msg)],
)?;
let res = self.call_host_func_impl("HostPrint", vec![ParameterValue::String(msg)])?;
res.try_into()
.map_err(|_| HostFunctionNotFound("HostPrint".to_string()))
}
Expand All @@ -84,97 +86,45 @@ impl HostFuncsWrapper {
name: &str,
args: Vec<ParameterValue>,
) -> Result<ReturnValue> {
call_host_func_impl(&self.functions_map, name, args)
self.call_host_func_impl(name, args)
}
}

fn register_host_function_helper(
self_: &mut HostFuncsWrapper,
name: String,
func: HyperlightFunction,
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
) -> Result<()> {
if let Some(_syscalls) = extra_allowed_syscalls {
#[cfg(all(feature = "seccomp", target_os = "linux"))]
self_.functions_map.insert(name, (func, Some(_syscalls)));

fn register_host_function_helper(
&mut self,
name: String,
function: HyperlightFunction,
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
) -> Result<()> {
#[cfg(not(all(feature = "seccomp", target_os = "linux")))]
return Err(new_error!(
"Extra syscalls are only supported on Linux with seccomp"
));
} else {
self_.functions_map.insert(name, (func, None));
if extra_allowed_syscalls.is_some() {
return Err(new_error!(
"Extra syscalls are only supported on Linux with seccomp"
));
}
self.functions_map.insert(
name,
FunctionEntry {
function,
extra_allowed_syscalls,
},
);
Ok(())
}

Ok(())
}

#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
fn call_host_func_impl(
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
name: &str,
args: Vec<ParameterValue>,
) -> Result<ReturnValue> {
// Inner function containing the common logic
fn call_func(
host_funcs: &HashMap<String, (HyperlightFunction, Option<Vec<ExtraAllowedSyscall>>)>,
name: &str,
args: Vec<ParameterValue>,
) -> Result<ReturnValue> {
let func_with_syscalls = host_funcs
#[instrument(err(Debug), skip_all, parent = Span::current(), level = "Trace")]
fn call_host_func_impl(&self, name: &str, args: Vec<ParameterValue>) -> Result<ReturnValue> {
let FunctionEntry {
function,
extra_allowed_syscalls,
} = self
.functions_map
.get(name)
.ok_or_else(|| HostFunctionNotFound(name.to_string()))?;

let func = func_with_syscalls.0.clone();

#[cfg(all(feature = "seccomp", target_os = "linux"))]
{
let syscalls = func_with_syscalls.1.clone();
let seccomp_filter =
crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread(
syscalls,
)?;
seccompiler::apply_filter(&seccomp_filter)?;
}

crate::metrics::maybe_time_and_emit_host_call(name, || func.call(args))
}

cfg_if::cfg_if! {
if #[cfg(all(feature = "seccomp", target_os = "linux"))] {
// Clone variables for the thread
let host_funcs_cloned = host_funcs.clone();
let name_cloned = name.to_string();
let args_cloned = args.clone();

// Create a new thread when seccomp is enabled on Linux
let join_handle = std::thread::Builder::new()
.name(format!("Host Function Worker Thread for: {:?}", name_cloned))
.spawn(move || {
// We have a `catch_unwind` here because, if a disallowed syscall is issued,
// we handle it by panicking. This is to avoid returning execution to the
// offending host function—for two reasons: (1) if a host function is issuing
// disallowed syscalls, it could be unsafe to return to, and (2) returning
// execution after trapping the disallowed syscall can lead to UB (e.g., try
// running a host function that attempts to sleep without `SYS_clock_nanosleep`,
// you'll block the syscall but panic in the aftermath).
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| call_func(&host_funcs_cloned, &name_cloned, args_cloned))) {
Ok(val) => val,
Err(err) => {
if let Some(crate::HyperlightError::DisallowedSyscall) = err.downcast_ref::<crate::HyperlightError>() {
return Err(crate::HyperlightError::DisallowedSyscall)
}

crate::log_then_return!("Host function {} panicked", name_cloned);
}
}
})?;

join_handle.join().map_err(|_| new_error!("Error joining thread executing host function"))?
} else {
// Directly call the function without creating a new thread
call_func(host_funcs, name, args)
}
// Create a new thread when seccomp is enabled on Linux
maybe_with_seccomp(name, extra_allowed_syscalls.as_deref(), || {
crate::metrics::maybe_time_and_emit_host_call(name, || function.call(args))
})
}
}

Expand All @@ -197,3 +147,55 @@ pub(super) fn default_writer_func(s: String) -> Result<i32> {
}
}
}

#[cfg(all(feature = "seccomp", target_os = "linux"))]
fn maybe_with_seccomp<T: Send>(
name: &str,
syscalls: Option<&[ExtraAllowedSyscall]>,
f: impl FnOnce() -> Result<T> + Send,
) -> Result<T> {
use crate::seccomp::guest::get_seccomp_filter_for_host_function_worker_thread;

// Use a scoped thread so that we can pass around references without having to clone them.
crossbeam::thread::scope(|s| {
s.builder()
.name(format!("Host Function Worker Thread for: {name:?}",))
.spawn(move |_| {
let seccomp_filter = get_seccomp_filter_for_host_function_worker_thread(syscalls)?;
seccompiler::apply_filter(&seccomp_filter)?;

// We have a `catch_unwind` here because, if a disallowed syscall is issued,
// we handle it by panicking. This is to avoid returning execution to the
// offending host function—for two reasons: (1) if a host function is issuing
// disallowed syscalls, it could be unsafe to return to, and (2) returning
// execution after trapping the disallowed syscall can lead to UB (e.g., try
// running a host function that attempts to sleep without `SYS_clock_nanosleep`,
// you'll block the syscall but panic in the aftermath).
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(val) => val,
Err(err) => {
if let Some(crate::HyperlightError::DisallowedSyscall) =
err.downcast_ref::<crate::HyperlightError>()
{
return Err(crate::HyperlightError::DisallowedSyscall);
}

crate::log_then_return!("Host function {} panicked", name);
}
}
})?
.join()
.map_err(|_| new_error!("Error joining thread executing host function"))?
})
.map_err(|_| new_error!("Error joining thread executing host function"))?
}

#[cfg(not(all(feature = "seccomp", target_os = "linux")))]
fn maybe_with_seccomp<T: Send>(
_name: &str,
_syscalls: Option<&[ExtraAllowedSyscall]>,
f: impl FnOnce() -> Result<T> + Send,
) -> Result<T> {
// No seccomp, just call the function
f()
}
6 changes: 3 additions & 3 deletions src/hyperlight_host/src/sandbox/initialized_multi_use.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{
};
use tracing::{instrument, Span};

use super::host_funcs::HostFuncsWrapper;
use super::host_funcs::FunctionRegistry;
use super::{MemMgrWrapper, WrapperGetter};
use crate::func::call_ctx::MultiUseGuestCallContext;
use crate::func::guest_dispatch::call_function_on_guest;
Expand All @@ -41,7 +41,7 @@ use crate::Result;
/// in this case the state of the sandbox is not reset until the context is finished and the `MultiUseSandbox` is returned.
pub struct MultiUseSandbox {
// We need to keep a reference to the host functions, even if the compiler marks it as unused. The compiler cannot detect our dynamic usages of the host function in `HyperlightFunction::call`.
pub(super) _host_funcs: Arc<Mutex<HostFuncsWrapper>>,
pub(super) _host_funcs: Arc<Mutex<FunctionRegistry>>,
pub(crate) mem_mgr: MemMgrWrapper<HostSharedMemory>,
hv_handler: HypervisorHandler,
}
Expand Down Expand Up @@ -73,7 +73,7 @@ impl MultiUseSandbox {
/// (as a `From` implementation would be)
#[instrument(skip_all, parent = Span::current(), level = "Trace")]
pub(super) fn from_uninit(
host_funcs: Arc<Mutex<HostFuncsWrapper>>,
host_funcs: Arc<Mutex<FunctionRegistry>>,
mgr: MemMgrWrapper<HostSharedMemory>,
hv_handler: HypervisorHandler,
) -> MultiUseSandbox {
Expand Down
6 changes: 3 additions & 3 deletions src/hyperlight_host/src/sandbox/outb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use log::{Level, Record};
use tracing::{instrument, Span};
use tracing_log::format_trace;

use super::host_funcs::HostFuncsWrapper;
use super::host_funcs::FunctionRegistry;
use super::mem_mgr::MemMgrWrapper;
use crate::hypervisor::handlers::{OutBHandler, OutBHandlerFunction, OutBHandlerWrapper};
use crate::mem::mgr::SandboxMemoryManager;
Expand Down Expand Up @@ -97,7 +97,7 @@ pub(super) fn outb_log(mgr: &mut SandboxMemoryManager<HostSharedMemory>) -> Resu
#[instrument(err(Debug), skip_all, parent = Span::current(), level= "Trace")]
fn handle_outb_impl(
mem_mgr: &mut MemMgrWrapper<HostSharedMemory>,
host_funcs: Arc<Mutex<HostFuncsWrapper>>,
host_funcs: Arc<Mutex<FunctionRegistry>>,
port: u16,
data: Vec<u8>,
) -> Result<()> {
Expand Down Expand Up @@ -153,7 +153,7 @@ fn handle_outb_impl(
#[instrument(skip_all, parent = Span::current(), level= "Trace")]
pub(crate) fn outb_handler_wrapper(
mut mem_mgr_wrapper: MemMgrWrapper<HostSharedMemory>,
host_funcs_wrapper: Arc<Mutex<HostFuncsWrapper>>,
host_funcs_wrapper: Arc<Mutex<FunctionRegistry>>,
) -> OutBHandlerWrapper {
let outb_func: OutBHandlerFunction = Box::new(move |port, payload| {
handle_outb_impl(
Expand Down
6 changes: 3 additions & 3 deletions src/hyperlight_host/src/sandbox/uninitialized.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use tracing::{instrument, Span};

#[cfg(gdb)]
use super::config::DebugInfo;
use super::host_funcs::{default_writer_func, HostFuncsWrapper};
use super::host_funcs::{default_writer_func, FunctionRegistry};
use super::mem_mgr::MemMgrWrapper;
use super::run_options::SandboxRunOptions;
use super::uninitialized_evolve::evolve_impl_multi_use;
Expand All @@ -48,7 +48,7 @@ use crate::{log_build_details, log_then_return, new_error, MultiUseSandbox, Resu
/// `UninitializedSandbox` into an initialized `Sandbox`.
pub struct UninitializedSandbox {
/// Registered host functions
pub(crate) host_funcs: Arc<Mutex<HostFuncsWrapper>>,
pub(crate) host_funcs: Arc<Mutex<FunctionRegistry>>,
/// The memory manager for the sandbox.
pub(crate) mgr: MemMgrWrapper<ExclusiveSharedMemory>,
pub(crate) run_inprocess: bool,
Expand Down Expand Up @@ -184,7 +184,7 @@ impl UninitializedSandbox {

mem_mgr_wrapper.write_memory_layout(run_inprocess)?;

let host_funcs = Arc::new(Mutex::new(HostFuncsWrapper::default()));
let host_funcs = Arc::new(Mutex::new(FunctionRegistry::default()));

let mut sandbox = Self {
host_funcs,
Expand Down
6 changes: 3 additions & 3 deletions src/hyperlight_host/src/sandbox/uninitialized_evolve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::mem::ptr::RawPtr;
use crate::mem::shared_mem::GuestSharedMemory;
#[cfg(gdb)]
use crate::sandbox::config::DebugInfo;
use crate::sandbox::host_funcs::HostFuncsWrapper;
use crate::sandbox::host_funcs::FunctionRegistry;
use crate::sandbox::mem_access::mem_access_handler_wrapper;
use crate::sandbox::outb::outb_handler_wrapper;
use crate::sandbox::{HostSharedMemory, MemMgrWrapper};
Expand All @@ -56,7 +56,7 @@ fn evolve_impl<TransformFunc, ResSandbox: Sandbox>(
) -> Result<ResSandbox>
where
TransformFunc: Fn(
Arc<Mutex<HostFuncsWrapper>>,
Arc<Mutex<FunctionRegistry>>,
MemMgrWrapper<HostSharedMemory>,
HypervisorHandler,
) -> Result<ResSandbox>,
Expand Down Expand Up @@ -105,7 +105,7 @@ pub(super) fn evolve_impl_multi_use(u_sbox: UninitializedSandbox) -> Result<Mult
fn hv_init(
hshm: &MemMgrWrapper<HostSharedMemory>,
gshm: SandboxMemoryManager<GuestSharedMemory>,
host_funcs: Arc<Mutex<HostFuncsWrapper>>,
host_funcs: Arc<Mutex<FunctionRegistry>>,
max_init_time: Duration,
max_exec_time: Duration,
max_wait_for_cancellation: Duration,
Expand Down
5 changes: 3 additions & 2 deletions src/hyperlight_host/src/seccomp/guest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,15 @@ fn syscalls_allowlist() -> Result<Vec<(i64, Vec<SeccompRule>)>> {
/// (e.g., `KVM_SET_USER_MEMORY_REGION`, `KVM_GET_API_VERSION`, `KVM_CREATE_VM`,
/// or `KVM_CREATE_VCPU`).
pub(crate) fn get_seccomp_filter_for_host_function_worker_thread(
extra_allowed_syscalls: Option<Vec<ExtraAllowedSyscall>>,
extra_allowed_syscalls: Option<&[ExtraAllowedSyscall]>,
) -> Result<BpfProgram> {
let mut allowed_syscalls = syscalls_allowlist()?;

if let Some(extra_allowed_syscalls) = extra_allowed_syscalls {
allowed_syscalls.extend(
extra_allowed_syscalls
.into_iter()
.iter()
.copied()
.map(|syscall| (syscall, vec![])),
);

Expand Down