Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ license = "Apache-2.0/MIT"
readme = "README.md"

[features]
default = ["std"]
macro = ["pollster-macro"]
std = []

[dependencies]
pollster-macro = { version = "0.1", path = "macro", optional = true }
Expand Down
235 changes: 164 additions & 71 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#![doc = include_str!("../README.md")]

use std::{
#![cfg_attr(not(feature = "std"), no_std)]

use core::{
future::Future,
sync::{Arc, Condvar, Mutex},
task::{Context, Poll, Wake, Waker},
mem,
task::{Context, Poll, RawWaker, RawWakerVTable, Waker},
};

#[cfg(feature = "std")]
use std::thread::{self, Thread};

#[cfg(feature = "macro")]
pub use pollster_macro::{main, test};

Expand All @@ -23,111 +28,199 @@ pub trait FutureExt: Future {
/// let result = my_fut.block_on();
/// ```
fn block_on(self) -> Self::Output where Self: Sized { block_on(self) }

/// Block the thread until the future is ready with custom thread parking implementation.
///
/// This allows one to use custom thread parking mechanisms in `no_std` environments.
///
/// # Example
///
/// ```
/// use pollster::FutureExt as _;
/// use std::thread::Thread;
///
/// let my_fut = async {};
///
/// let result = my_fut.block_on_t::<Thread>();
/// ```
fn block_on_t<T: Parkable>(self) -> Self::Output where Self: Sized { block_on_t::<T, Self>(self) }
}

impl<F: Future> FutureExt for F {}

enum SignalState {
Empty,
Waiting,
Notified,
/// Parkable handle.
///
/// This handle allows a thread to potentially be efficiently blocked. This is used in the polling
/// implementation to wait for wakeups.
///
/// The interface models that of `std::thread`, and many functions, such as
/// [`current`](Parkable::current), [`park`](Parkable::park), and [`unpark`](Parkable::unpark)
/// map to `std::thread` equivalents.
pub trait Parkable: Sized + Clone {
/// Get handle to current thread.
fn current() -> Self;

/// Park the current thread.
fn park();

/// Unpark specified thread.
fn unpark(&self);

/// Convert self into opaque pointer.
///
/// This requires `Self` to either be layout compatible with `*const ()` or heap allocated upon
/// switch.
fn into_opaque(self) -> *const ();

/// Convert opaque pointer into `Self`.
///
/// # Safety
///
/// This function is safe if the `data` argument is a valid park handle created by
/// `Self::into_opaque`.
unsafe fn from_opaque(data: *const ()) -> Self;

/// Create a waker out of `self`
///
/// This function will clone self and build a `Waker` object.
fn waker(&self) -> Waker {
let data = self.clone().into_opaque();
// SAFETY: `RawWaker` created by `raw_waker` builds a waker object out of the raw data and
// vtable methods of this type which we assume are correct.
unsafe {
Waker::from_raw(raw_waker::<Self>(data))
}
}
}

struct Signal {
state: Mutex<SignalState>,
cond: Condvar,
#[cfg(feature = "std")]
pub type DefaultHandle = Thread;
#[cfg(not(feature = "std"))]
pub type DefaultHandle = *const ();

fn raw_waker<T: Parkable>(data: *const ()) -> RawWaker {
RawWaker::new(
data,
&RawWakerVTable::new(
clone_waker::<T>,
wake::<T>,
wake_by_ref::<T>,
drop_waker::<T>,
),
)
}

impl Signal {
fn new() -> Self {
Self {
state: Mutex::new(SignalState::Empty),
cond: Condvar::new(),
}
unsafe fn clone_waker<T: Parkable>(data: *const ()) -> RawWaker {
let waker = T::from_opaque(data);
mem::forget(waker.clone());
mem::forget(waker);
raw_waker::<T>(data)
}

unsafe fn wake<T: Parkable>(data: *const ()) {
let waker = T::from_opaque(data);
waker.unpark();
}

unsafe fn wake_by_ref<T: Parkable>(data: *const ()) {
let waker = T::from_opaque(data);
waker.unpark();
mem::forget(waker);
}

unsafe fn drop_waker<T: Parkable>(data: *const ()) {
let _ = T::from_opaque(data);
}

#[cfg(feature = "std")]
impl Parkable for Thread {
fn current() -> Self {
thread::current()
}

fn wait(&self) {
let mut state = self.state.lock().unwrap();
match *state {
SignalState::Notified => {
// Notify() was called before we got here, consume it here without waiting and return immediately.
*state = SignalState::Empty;
return;
}
// This should not be possible because our signal is created within a function and never handed out to any
// other threads. If this is the case, we have a serious problem so we panic immediately to avoid anything
// more problematic happening.
SignalState::Waiting => {
unreachable!("Multiple threads waiting on the same signal: Open a bug report!");
}
SignalState::Empty => {
// Nothing has happened yet, and we're the only thread waiting (as should be the case!). Set the state
// accordingly and begin polling the condvar in a loop until it's no longer telling us to wait. The
// loop prevents incorrect spurious wakeups.
*state = SignalState::Waiting;
while let SignalState::Waiting = *state {
state = self.cond.wait(state).unwrap();
}
}
}
fn park() {
thread::park();
}

fn notify(&self) {
let mut state = self.state.lock().unwrap();
match *state {
// The signal was already notified, no need to do anything because the thread will be waking up anyway
SignalState::Notified => {}
// The signal wasn't notified but a thread isn't waiting on it, so we can avoid doing unnecessary work by
// skipping the condvar and leaving behind a message telling the thread that a notification has already
// occurred should it come along in the future.
SignalState::Empty => *state = SignalState::Notified,
// The signal wasn't notified and there's a waiting thread. Reset the signal so it can be wait()'ed on again
// and wake up the thread. Because there should only be a single thread waiting, `notify_all` would also be
// valid.
SignalState::Waiting => {
*state = SignalState::Empty;
self.cond.notify_one();
}
}
fn unpark(&self) {
Thread::unpark(self);
}

fn into_opaque(self) -> *const () {
// SAFETY: `Thread` internal layout is an Arc to inner type, which is represented as a
// single pointer. The only thing we do with the pointer is transmute it back to
// ThreadWaker in the waker functions. If for whatever reason Thread layout will change to
// contain multiple fields, this will still be safe, because the compiler will simply
// refuse to compile the program.
unsafe { mem::transmute::<_, *const ()>(self) }
}

unsafe fn from_opaque(data: *const ()) -> Self {
mem::transmute(data)
}
}

impl Wake for Signal {
fn wake(self: Arc<Self>) {
self.notify();
impl Parkable for *const () {
fn current() -> Self {
core::ptr::null()
}

fn park() {
core::hint::spin_loop()
}

fn unpark(&self) {}

fn into_opaque(self) -> *const () {
self
}

unsafe fn from_opaque(data: *const ()) -> Self {
data
}
}

/// Block the thread until the future is ready.
/// Block the thread until the future is ready with custom parking implementation.
///
/// This allows one to use custom thread parking mechanisms in `no_std` environments.
///
/// # Example
///
/// ```
/// use std::thread::Thread;
///
/// let my_fut = async {};
/// let result = pollster::block_on(my_fut);
/// let result = pollster::block_on_t::<Thread, _>(my_fut);
/// ```
pub fn block_on<F: Future>(mut fut: F) -> F::Output {
pub fn block_on_t<T: Parkable, F: Future>(mut fut: F) -> F::Output {
// Pin the future so that it can be polled.
// SAFETY: We shadow `fut` so that it cannot be used again. The future is now pinned to the stack and will not be
// moved until the end of this scope. This is, incidentally, exactly what the `pin_mut!` macro from `pin_utils`
// does.
let mut fut = unsafe { std::pin::Pin::new_unchecked(&mut fut) };

// Signal used to wake up the thread for polling as the future moves to completion. We need to use an `Arc`
// because, although the lifetime of `fut` is limited to this function, the underlying IO abstraction might keep
// the signal alive for far longer. `Arc` is a thread-safe way to allow this to happen.
// TODO: Investigate ways to reuse this `Arc<Signal>`... perhaps via a `static`?
let signal = Arc::new(Signal::new());
let handle = T::current();

// Create a context that will be passed to the future.
let waker = Waker::from(Arc::clone(&signal));
let waker: Waker = handle.waker();
let mut context = Context::from_waker(&waker);

// Poll the future to completion
loop {
match fut.as_mut().poll(&mut context) {
Poll::Pending => signal.wait(),
Poll::Pending => T::park(),
Poll::Ready(item) => break item,
}
}
}

/// Block the thread until the future is ready.
///
/// # Example
///
/// ```
/// let my_fut = async {};
/// let result = pollster::block_on(my_fut);
/// ```
pub fn block_on<F: Future>(fut: F) -> F::Output {
return block_on_t::<DefaultHandle, _>(fut);
}