@@ -28,19 +28,25 @@ use crate::{new_error, Result};
28
28
29
29
#[ derive( Default , Clone ) ]
30
30
/// A Wrapper around details of functions exposed by the Host
31
- pub struct HostFuncsWrapper {
32
- functions_map : HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
31
+ pub struct FunctionRegistry {
32
+ functions_map : HashMap < String , FunctionEntry > ,
33
33
}
34
34
35
- impl HostFuncsWrapper {
35
+ #[ derive( Clone ) ]
36
+ pub struct FunctionEntry {
37
+ pub function : HyperlightFunction ,
38
+ pub extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
39
+ }
40
+
41
+ impl FunctionRegistry {
36
42
/// Register a host function with the sandbox.
37
43
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
38
44
pub ( crate ) fn register_host_function (
39
45
& mut self ,
40
46
name : String ,
41
47
func : HyperlightFunction ,
42
48
) -> Result < ( ) > {
43
- register_host_function_helper ( self , name, func, None )
49
+ self . register_host_function_helper ( name, func, None )
44
50
}
45
51
46
52
/// Register a host function with the sandbox, with a list of extra syscalls
@@ -53,7 +59,7 @@ impl HostFuncsWrapper {
53
59
func : HyperlightFunction ,
54
60
extra_allowed_syscalls : Vec < ExtraAllowedSyscall > ,
55
61
) -> Result < ( ) > {
56
- register_host_function_helper ( self , name, func, Some ( extra_allowed_syscalls) )
62
+ self . register_host_function_helper ( name, func, Some ( extra_allowed_syscalls) )
57
63
}
58
64
59
65
/// Assuming a host function called `"HostPrint"` exists, and takes a
@@ -63,11 +69,7 @@ impl HostFuncsWrapper {
63
69
/// and `Err` otherwise.
64
70
#[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
65
71
pub ( super ) fn host_print ( & mut self , msg : String ) -> Result < i32 > {
66
- let res = call_host_func_impl (
67
- & self . functions_map ,
68
- "HostPrint" ,
69
- vec ! [ ParameterValue :: String ( msg) ] ,
70
- ) ?;
72
+ let res = self . call_host_func_impl ( "HostPrint" , vec ! [ ParameterValue :: String ( msg) ] ) ?;
71
73
res. try_into ( )
72
74
. map_err ( |_| HostFunctionNotFound ( "HostPrint" . to_string ( ) ) )
73
75
}
@@ -84,97 +86,45 @@ impl HostFuncsWrapper {
84
86
name : & str ,
85
87
args : Vec < ParameterValue > ,
86
88
) -> Result < ReturnValue > {
87
- call_host_func_impl ( & self . functions_map , name, args)
89
+ self . call_host_func_impl ( name, args)
88
90
}
89
- }
90
-
91
- fn register_host_function_helper (
92
- self_ : & mut HostFuncsWrapper ,
93
- name : String ,
94
- func : HyperlightFunction ,
95
- extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
96
- ) -> Result < ( ) > {
97
- if let Some ( _syscalls) = extra_allowed_syscalls {
98
- #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
99
- self_. functions_map . insert ( name, ( func, Some ( _syscalls) ) ) ;
100
91
92
+ fn register_host_function_helper (
93
+ & mut self ,
94
+ name : String ,
95
+ function : HyperlightFunction ,
96
+ extra_allowed_syscalls : Option < Vec < ExtraAllowedSyscall > > ,
97
+ ) -> Result < ( ) > {
101
98
#[ cfg( not( all( feature = "seccomp" , target_os = "linux" ) ) ) ]
102
- return Err ( new_error ! (
103
- "Extra syscalls are only supported on Linux with seccomp"
104
- ) ) ;
105
- } else {
106
- self_. functions_map . insert ( name, ( func, None ) ) ;
99
+ if extra_allowed_syscalls. is_some ( ) {
100
+ return Err ( new_error ! (
101
+ "Extra syscalls are only supported on Linux with seccomp"
102
+ ) ) ;
103
+ }
104
+ self . functions_map . insert (
105
+ name,
106
+ FunctionEntry {
107
+ function,
108
+ extra_allowed_syscalls,
109
+ } ,
110
+ ) ;
111
+ Ok ( ( ) )
107
112
}
108
113
109
- Ok ( ( ) )
110
- }
111
-
112
- #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
113
- fn call_host_func_impl (
114
- host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
115
- name : & str ,
116
- args : Vec < ParameterValue > ,
117
- ) -> Result < ReturnValue > {
118
- // Inner function containing the common logic
119
- fn call_func (
120
- host_funcs : & HashMap < String , ( HyperlightFunction , Option < Vec < ExtraAllowedSyscall > > ) > ,
121
- name : & str ,
122
- args : Vec < ParameterValue > ,
123
- ) -> Result < ReturnValue > {
124
- let func_with_syscalls = host_funcs
114
+ #[ instrument( err( Debug ) , skip_all, parent = Span :: current( ) , level = "Trace" ) ]
115
+ fn call_host_func_impl ( & self , name : & str , args : Vec < ParameterValue > ) -> Result < ReturnValue > {
116
+ let FunctionEntry {
117
+ function,
118
+ extra_allowed_syscalls,
119
+ } = self
120
+ . functions_map
125
121
. get ( name)
126
122
. ok_or_else ( || HostFunctionNotFound ( name. to_string ( ) ) ) ?;
127
123
128
- let func = func_with_syscalls. 0 . clone ( ) ;
129
-
130
- #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
131
- {
132
- let syscalls = func_with_syscalls. 1 . clone ( ) ;
133
- let seccomp_filter =
134
- crate :: seccomp:: guest:: get_seccomp_filter_for_host_function_worker_thread (
135
- syscalls,
136
- ) ?;
137
- seccompiler:: apply_filter ( & seccomp_filter) ?;
138
- }
139
-
140
- crate :: metrics:: maybe_time_and_emit_host_call ( name, || func. call ( args) )
141
- }
142
-
143
- cfg_if:: cfg_if! {
144
- if #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ] {
145
- // Clone variables for the thread
146
- let host_funcs_cloned = host_funcs. clone( ) ;
147
- let name_cloned = name. to_string( ) ;
148
- let args_cloned = args. clone( ) ;
149
-
150
- // Create a new thread when seccomp is enabled on Linux
151
- let join_handle = std:: thread:: Builder :: new( )
152
- . name( format!( "Host Function Worker Thread for: {:?}" , name_cloned) )
153
- . spawn( move || {
154
- // We have a `catch_unwind` here because, if a disallowed syscall is issued,
155
- // we handle it by panicking. This is to avoid returning execution to the
156
- // offending host function—for two reasons: (1) if a host function is issuing
157
- // disallowed syscalls, it could be unsafe to return to, and (2) returning
158
- // execution after trapping the disallowed syscall can lead to UB (e.g., try
159
- // running a host function that attempts to sleep without `SYS_clock_nanosleep`,
160
- // you'll block the syscall but panic in the aftermath).
161
- match std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || call_func( & host_funcs_cloned, & name_cloned, args_cloned) ) ) {
162
- Ok ( val) => val,
163
- Err ( err) => {
164
- if let Some ( crate :: HyperlightError :: DisallowedSyscall ) = err. downcast_ref:: <crate :: HyperlightError >( ) {
165
- return Err ( crate :: HyperlightError :: DisallowedSyscall )
166
- }
167
-
168
- crate :: log_then_return!( "Host function {} panicked" , name_cloned) ;
169
- }
170
- }
171
- } ) ?;
172
-
173
- join_handle. join( ) . map_err( |_| new_error!( "Error joining thread executing host function" ) ) ?
174
- } else {
175
- // Directly call the function without creating a new thread
176
- call_func( host_funcs, name, args)
177
- }
124
+ // Create a new thread when seccomp is enabled on Linux
125
+ maybe_with_seccomp ( name, extra_allowed_syscalls. as_deref ( ) , || {
126
+ crate :: metrics:: maybe_time_and_emit_host_call ( name, || function. call ( args) )
127
+ } )
178
128
}
179
129
}
180
130
@@ -197,3 +147,55 @@ pub(super) fn default_writer_func(s: String) -> Result<i32> {
197
147
}
198
148
}
199
149
}
150
+
151
+ #[ cfg( all( feature = "seccomp" , target_os = "linux" ) ) ]
152
+ fn maybe_with_seccomp < T : Send > (
153
+ name : & str ,
154
+ syscalls : Option < & [ ExtraAllowedSyscall ] > ,
155
+ f : impl FnOnce ( ) -> Result < T > + Send ,
156
+ ) -> Result < T > {
157
+ use crate :: seccomp:: guest:: get_seccomp_filter_for_host_function_worker_thread;
158
+
159
+ // Use a scoped thread so that we can pass around references without having to clone them.
160
+ crossbeam:: thread:: scope ( |s| {
161
+ s. builder ( )
162
+ . name ( format ! ( "Host Function Worker Thread for: {name:?}" , ) )
163
+ . spawn ( move |_| {
164
+ let seccomp_filter = get_seccomp_filter_for_host_function_worker_thread ( syscalls) ?;
165
+ seccompiler:: apply_filter ( & seccomp_filter) ?;
166
+
167
+ // We have a `catch_unwind` here because, if a disallowed syscall is issued,
168
+ // we handle it by panicking. This is to avoid returning execution to the
169
+ // offending host function—for two reasons: (1) if a host function is issuing
170
+ // disallowed syscalls, it could be unsafe to return to, and (2) returning
171
+ // execution after trapping the disallowed syscall can lead to UB (e.g., try
172
+ // running a host function that attempts to sleep without `SYS_clock_nanosleep`,
173
+ // you'll block the syscall but panic in the aftermath).
174
+ match std:: panic:: catch_unwind ( std:: panic:: AssertUnwindSafe ( f) ) {
175
+ Ok ( val) => val,
176
+ Err ( err) => {
177
+ if let Some ( crate :: HyperlightError :: DisallowedSyscall ) =
178
+ err. downcast_ref :: < crate :: HyperlightError > ( )
179
+ {
180
+ return Err ( crate :: HyperlightError :: DisallowedSyscall ) ;
181
+ }
182
+
183
+ crate :: log_then_return!( "Host function {} panicked" , name) ;
184
+ }
185
+ }
186
+ } ) ?
187
+ . join ( )
188
+ . map_err ( |_| new_error ! ( "Error joining thread executing host function" ) ) ?
189
+ } )
190
+ . map_err ( |_| new_error ! ( "Error joining thread executing host function" ) ) ?
191
+ }
192
+
193
+ #[ cfg( not( all( feature = "seccomp" , target_os = "linux" ) ) ) ]
194
+ fn maybe_with_seccomp < T : Send > (
195
+ _name : & str ,
196
+ _syscalls : Option < & [ ExtraAllowedSyscall ] > ,
197
+ f : impl FnOnce ( ) -> Result < T > + Send ,
198
+ ) -> Result < T > {
199
+ // No seccomp, just call the function
200
+ f ( )
201
+ }
0 commit comments