Skip to content

Commit 95b558e

Browse files
committed
Synchronize cancel queue with usercall queue
1 parent 556765d commit 95b558e

File tree

7 files changed

+188
-44
lines changed

7 files changed

+188
-44
lines changed

enclave-runner/src/usercalls/mod.rs

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use tokio::sync::broadcast;
3333
use tokio::sync::mpsc as async_mpsc;
3434

3535
use fortanix_sgx_abi::*;
36-
use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent};
36+
use ipc_queue::{self, DescriptorGuard, Identified, QueueEvent, WritePosition};
3737
use sgxs::loader::Tcs as SgxsTcs;
3838

3939
use crate::loader::{EnclavePanic, ErasedTcs};
@@ -636,26 +636,22 @@ impl Work {
636636
enum UsercallEvent {
637637
Started(u64, tokio::sync::oneshot::Sender<()>),
638638
Finished(u64),
639-
Cancelled(u64, Instant),
640-
}
641-
642-
fn ignore_cancel_impl(usercall_nr: u64) -> bool {
643-
usercall_nr != UsercallList::read as u64 &&
644-
usercall_nr != UsercallList::read_alloc as u64 &&
645-
usercall_nr != UsercallList::write as u64 &&
646-
usercall_nr != UsercallList::accept_stream as u64 &&
647-
usercall_nr != UsercallList::connect_stream as u64 &&
648-
usercall_nr != UsercallList::wait as u64
639+
Cancelled(u64, WritePosition),
649640
}
650641

651642
trait IgnoreCancel {
652643
fn ignore_cancel(&self) -> bool;
653644
}
645+
654646
impl IgnoreCancel for Identified<Usercall> {
655-
fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.0) }
656-
}
657-
impl IgnoreCancel for Identified<Cancel> {
658-
fn ignore_cancel(&self) -> bool { ignore_cancel_impl(self.data.usercall_nr) }
647+
fn ignore_cancel(&self) -> bool {
648+
self.data.0 != UsercallList::read as u64 &&
649+
self.data.0 != UsercallList::read_alloc as u64 &&
650+
self.data.0 != UsercallList::write as u64 &&
651+
self.data.0 != UsercallList::accept_stream as u64 &&
652+
self.data.0 != UsercallList::connect_stream as u64 &&
653+
self.data.0 != UsercallList::wait as u64
654+
}
659655
}
660656

661657
impl EnclaveState {
@@ -892,6 +888,8 @@ impl EnclaveState {
892888
*enclave_clone.fifo_guards.lock().await = Some(fifo_guards);
893889
*enclave_clone.return_queue_tx.lock().await = Some(return_queue_tx);
894890

891+
let usercall_queue_monitor = usercall_queue_rx.position_monitor();
892+
895893
tokio::task::spawn_local(async move {
896894
while let Ok(usercall) = usercall_queue_rx.recv().await {
897895
let _ = io_queue_send.send(UsercallSendData::Async(usercall));
@@ -900,37 +898,32 @@ impl EnclaveState {
900898

901899
let (usercall_event_tx, mut usercall_event_rx) = async_mpsc::unbounded_channel();
902900
let usercall_event_tx_clone = usercall_event_tx.clone();
901+
let usercall_queue_monitor_clone = usercall_queue_monitor.clone();
903902
tokio::task::spawn_local(async move {
904903
while let Ok(c) = cancel_queue_rx.recv().await {
905-
if !c.ignore_cancel() {
906-
let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, Instant::now()));
907-
}
904+
let write_position = usercall_queue_monitor_clone.write_position();
905+
let _ = usercall_event_tx_clone.send(UsercallEvent::Cancelled(c.id, write_position));
908906
}
909907
});
910908

911909
tokio::task::spawn_local(async move {
912910
let mut notifiers = HashMap::new();
913-
let mut cancels: HashMap<u64, Instant> = HashMap::new();
914-
// This should be greater than the amount of time it takes for the enclave runner
915-
// to start executing a usercall after the enclave sends it on the usercall_queue.
916-
const CANCEL_EXPIRY: Duration = Duration::from_millis(100);
911+
let mut cancels: HashMap<u64, WritePosition> = HashMap::new();
917912
loop {
918913
match usercall_event_rx.recv().await.expect("usercall_event channel closed unexpectedly") {
919914
UsercallEvent::Started(id, notifier) => match cancels.remove(&id) {
920-
Some(t) if t.elapsed() < CANCEL_EXPIRY => { let _ = notifier.send(()); },
915+
Some(_) => { let _ = notifier.send(()); },
921916
_ => { notifiers.insert(id, notifier); },
922917
},
923918
UsercallEvent::Finished(id) => { notifiers.remove(&id); },
924-
UsercallEvent::Cancelled(id, t) => if t.elapsed() < CANCEL_EXPIRY {
925-
match notifiers.remove(&id) {
926-
Some(notifier) => { let _ = notifier.send(()); },
927-
None => { cancels.insert(id, t); },
928-
}
919+
UsercallEvent::Cancelled(id, wp) => match notifiers.remove(&id) {
920+
Some(notifier) => { let _ = notifier.send(()); },
921+
None => { cancels.insert(id, wp); },
929922
},
930923
}
931-
// cleanup expired cancels
932-
let now = Instant::now();
933-
cancels.retain(|_id, &mut t| now - t < CANCEL_EXPIRY);
924+
// cleanup old cancels
925+
let read_position = usercall_queue_monitor.read_position();
926+
cancels.retain(|_id, wp| !read_position.is_past(wp));
934927
}
935928
});
936929

fortanix-sgx-abi/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,8 +718,8 @@ pub mod async {
718718
#[derive(Copy, Clone, Default)]
719719
#[cfg_attr(feature = "rustc-dep-of-std", unstable(feature = "sgx_platform", issue = "56975"))]
720720
pub struct Cancel {
721-
/// This must be the same value as `Usercall.0`.
722-
pub usercall_nr: u64,
721+
/// Reserved for future use.
722+
pub reserved: u64,
723723
}
724724

725725
/// A circular buffer used as a FIFO queue with atomic reads and writes.

ipc-queue/src/fifo.rs

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
use std::cell::UnsafeCell;
88
use std::mem;
9-
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
9+
use std::sync::atomic::{AtomicU32, AtomicU64, AtomicUsize, Ordering};
1010
use std::sync::Arc;
1111

1212
use fortanix_sgx_abi::{FifoDescriptor, WithId};
@@ -33,7 +33,7 @@ where
3333
let arc = Arc::new(FifoBuffer::new(len));
3434
let inner = Fifo::from_arc(arc);
3535
let tx = AsyncSender { inner: inner.clone(), synchronizer: s.clone() };
36-
let rx = AsyncReceiver { inner, synchronizer: s };
36+
let rx = AsyncReceiver { inner, synchronizer: s, read_epoch: Arc::new(AtomicU32::new(0)) };
3737
(tx, rx)
3838
}
3939

@@ -87,6 +87,12 @@ impl<T> Clone for Fifo<T> {
8787
}
8888
}
8989

90+
impl<T> Fifo<T> {
91+
pub(crate) fn current_offsets(&self, ordering: Ordering) -> Offsets {
92+
Offsets::new(self.offsets.load(ordering), self.data.len() as u32)
93+
}
94+
}
95+
9096
impl<T: Transmittable> Fifo<T> {
9197
pub(crate) unsafe fn from_descriptor(descriptor: FifoDescriptor<T>) -> Self {
9298
assert!(
@@ -152,7 +158,7 @@ impl<T: Transmittable> Fifo<T> {
152158
pub(crate) fn try_send_impl(&self, val: Identified<T>) -> Result</*wake up reader:*/ bool, TrySendError> {
153159
let (new, was_empty) = loop {
154160
// 1. Load the current offsets.
155-
let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32);
161+
let current = self.current_offsets(Ordering::SeqCst);
156162
let was_empty = current.is_empty();
157163

158164
// 2. If the queue is full, wait, then go to step 1.
@@ -179,9 +185,9 @@ impl<T: Transmittable> Fifo<T> {
179185
Ok(was_empty)
180186
}
181187

182-
pub(crate) fn try_recv_impl(&self) -> Result<(Identified<T>, /*wake up writer:*/ bool), TryRecvError> {
188+
pub(crate) fn try_recv_impl(&self) -> Result<(Identified<T>, /*wake up writer:*/ bool, /*read offset wrapped around:*/bool), TryRecvError> {
183189
// 1. Load the current offsets.
184-
let current = Offsets::new(self.offsets.load(Ordering::SeqCst), self.data.len() as u32);
190+
let current = self.current_offsets(Ordering::SeqCst);
185191

186192
// 2. If the queue is empty, wait, then go to step 1.
187193
if current.is_empty() {
@@ -216,7 +222,7 @@ impl<T: Transmittable> Fifo<T> {
216222

217223
// 8. If the queue was full before step 7, signal the writer to wake up.
218224
let was_full = Offsets::new(before, self.data.len() as u32).is_full();
219-
Ok((val, was_full))
225+
Ok((val, was_full, new.read_offset() == 0))
220226
}
221227
}
222228

@@ -282,6 +288,14 @@ impl Offsets {
282288
..*self
283289
}
284290
}
291+
292+
pub(crate) fn read_high_bit(&self) -> bool {
293+
self.read & self.len == self.len
294+
}
295+
296+
pub(crate) fn write_high_bit(&self) -> bool {
297+
self.write & self.len == self.len
298+
}
285299
}
286300

287301
#[cfg(test)]
@@ -308,7 +322,7 @@ mod tests {
308322
}
309323

310324
for i in 1..=7 {
311-
let (v, wake) = inner.try_recv_impl().unwrap();
325+
let (v, wake, _) = inner.try_recv_impl().unwrap();
312326
assert!(!wake);
313327
assert_eq!(v.id, i);
314328
assert_eq!(v.data.0, i);
@@ -327,7 +341,7 @@ mod tests {
327341
assert!(inner.try_send_impl(Identified { id: 9, data: TestValue(9) }).is_err());
328342

329343
for i in 1..=8 {
330-
let (v, wake) = inner.try_recv_impl().unwrap();
344+
let (v, wake, _) = inner.try_recv_impl().unwrap();
331345
assert!(if i == 1 { wake } else { !wake });
332346
assert_eq!(v.id, i);
333347
assert_eq!(v.data.0, i);

ipc-queue/src/interface_async.rs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
66

77
use super::*;
8+
use std::sync::atomic::Ordering;
89

910
unsafe impl<T: Send, S: Send> Send for AsyncSender<T, S> {}
1011
unsafe impl<T: Send, S: Sync> Sync for AsyncSender<T, S> {}
@@ -52,10 +53,13 @@ impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
5253
pub async fn recv(&self) -> Result<Identified<T>, RecvError> {
5354
loop {
5455
match self.inner.try_recv_impl() {
55-
Ok((val, wake_sender)) => {
56+
Ok((val, wake_sender, read_wrapped_around)) => {
5657
if wake_sender {
5758
self.synchronizer.notify(QueueEvent::NotFull);
5859
}
60+
if read_wrapped_around {
61+
self.read_epoch.fetch_add(1, Ordering::Relaxed);
62+
}
5963
return Ok(val);
6064
}
6165
Err(TryRecvError::QueueEmpty) => {
@@ -68,6 +72,13 @@ impl<T: Transmittable, S: AsyncSynchronizer> AsyncReceiver<T, S> {
6872
}
6973
}
7074

75+
pub fn position_monitor(&self) -> PositionMonitor<T> {
76+
PositionMonitor {
77+
read_epoch: self.read_epoch.clone(),
78+
fifo: self.inner.clone(),
79+
}
80+
}
81+
7182
/// Consumes `self` and returns a DescriptorGuard.
7283
/// The returned guard can be used to make `FifoDescriptor`s that remain
7384
/// valid as long as the guard is not dropped.
@@ -153,6 +164,65 @@ mod tests {
153164
do_multi_sender(1024, 30, 100).await;
154165
}
155166

167+
#[tokio::test]
168+
async fn positions() {
169+
const LEN: usize = 16;
170+
let s = TestAsyncSynchronizer::new();
171+
let (tx, rx) = bounded_async(LEN, s);
172+
let monitor = rx.position_monitor();
173+
let mut id = 1;
174+
175+
let p0 = monitor.write_position();
176+
tx.send(Identified { id, data: TestValue(1) }).await.unwrap();
177+
let p1 = monitor.write_position();
178+
tx.send(Identified { id: id + 1, data: TestValue(2) }).await.unwrap();
179+
let p2 = monitor.write_position();
180+
tx.send(Identified { id: id + 2, data: TestValue(3) }).await.unwrap();
181+
let p3 = monitor.write_position();
182+
id += 3;
183+
assert!(monitor.read_position().is_past(&p0) == false);
184+
assert!(monitor.read_position().is_past(&p1) == false);
185+
assert!(monitor.read_position().is_past(&p2) == false);
186+
assert!(monitor.read_position().is_past(&p3) == false);
187+
188+
rx.recv().await.unwrap();
189+
assert!(monitor.read_position().is_past(&p0) == true);
190+
assert!(monitor.read_position().is_past(&p1) == false);
191+
assert!(monitor.read_position().is_past(&p2) == false);
192+
assert!(monitor.read_position().is_past(&p3) == false);
193+
194+
rx.recv().await.unwrap();
195+
assert!(monitor.read_position().is_past(&p0) == true);
196+
assert!(monitor.read_position().is_past(&p1) == true);
197+
assert!(monitor.read_position().is_past(&p2) == false);
198+
assert!(monitor.read_position().is_past(&p3) == false);
199+
200+
rx.recv().await.unwrap();
201+
assert!(monitor.read_position().is_past(&p0) == true);
202+
assert!(monitor.read_position().is_past(&p1) == true);
203+
assert!(monitor.read_position().is_past(&p2) == true);
204+
assert!(monitor.read_position().is_past(&p3) == false);
205+
206+
for i in 0..1000 {
207+
let n = 1 + (i % LEN);
208+
let p4 = monitor.write_position();
209+
for _ in 0..n {
210+
tx.send(Identified { id, data: TestValue(id) }).await.unwrap();
211+
id += 1;
212+
}
213+
let p5 = monitor.write_position();
214+
for _ in 0..n {
215+
rx.recv().await.unwrap();
216+
assert!(monitor.read_position().is_past(&p0) == true);
217+
assert!(monitor.read_position().is_past(&p1) == true);
218+
assert!(monitor.read_position().is_past(&p2) == true);
219+
assert!(monitor.read_position().is_past(&p3) == true);
220+
assert!(monitor.read_position().is_past(&p4) == true);
221+
assert!(monitor.read_position().is_past(&p5) == false);
222+
}
223+
}
224+
}
225+
156226
struct Subscription<T> {
157227
tx: broadcast::Sender<T>,
158228
rx: Mutex<broadcast::Receiver<T>>,

ipc-queue/src/interface_sync.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
112112
}
113113

114114
pub fn try_recv(&self) -> Result<Identified<T>, TryRecvError> {
115-
self.inner.try_recv_impl().map(|(val, wake_sender)| {
115+
self.inner.try_recv_impl().map(|(val, wake_sender, _)| {
116116
if wake_sender {
117117
self.synchronizer.notify(QueueEvent::NotFull);
118118
}
@@ -127,7 +127,7 @@ impl<T: Transmittable, S: Synchronizer> Receiver<T, S> {
127127
pub fn recv(&self) -> Result<Identified<T>, RecvError> {
128128
loop {
129129
match self.inner.try_recv_impl() {
130-
Ok((val, wake_sender)) => {
130+
Ok((val, wake_sender, _)) => {
131131
if wake_sender {
132132
self.synchronizer.notify(QueueEvent::NotFull);
133133
}

ipc-queue/src/lib.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use std::future::Future;
1010
#[cfg(target_env = "sgx")]
1111
use std::os::fortanix_sgx::usercalls::alloc::UserSafeSized;
1212
use std::pin::Pin;
13+
use std::sync::atomic::AtomicU32;
1314
use std::sync::Arc;
1415

1516
use fortanix_sgx_abi::FifoDescriptor;
@@ -19,6 +20,7 @@ use self::fifo::{Fifo, FifoBuffer};
1920
mod fifo;
2021
mod interface_sync;
2122
mod interface_async;
23+
mod position;
2224
#[cfg(test)]
2325
mod test_support;
2426

@@ -123,6 +125,7 @@ pub struct AsyncSender<T: 'static, S> {
123125
pub struct AsyncReceiver<T: 'static, S> {
124126
inner: Fifo<T>,
125127
synchronizer: S,
128+
read_epoch: Arc<AtomicU32>,
126129
}
127130

128131
/// `DescriptorGuard<T>` can produce a `FifoDescriptor<T>` that is guaranteed
@@ -137,3 +140,19 @@ impl<T> DescriptorGuard<T> {
137140
self.descriptor
138141
}
139142
}
143+
144+
/// `PositionMonitor<T>` can be used to record the current read/write positions
145+
/// of a queue. Even though a queue is comprised of a limited number of slots
146+
/// arranged as a ring buffer, we can assign a position to each value written/
147+
/// read to/from the queue. This is useful in case we want to know whether or
148+
/// not a particular value written to the queue has been read.
149+
pub struct PositionMonitor<T: 'static> {
150+
read_epoch: Arc<AtomicU32>,
151+
fifo: Fifo<T>,
152+
}
153+
154+
/// A read position in a queue.
155+
pub struct ReadPosition(u64);
156+
157+
/// A write position in a queue.
158+
pub struct WritePosition(u64);

0 commit comments

Comments
 (0)