Skip to content

Commit 1cec0ae

Browse files
committed
refactor: ditch slot
It turns out the `parking_lot` `arc_lock` future does the main thing we want from `Slot`, giving us `'static` lock guards that unlock the value on drop. The only missing functionality is the "stealing" we need to obtain ownership in order to commit the transaction. Rather than implementing this via `Option`, we add an additional state to `LazyTransaction` and handle it there. Ultimately this removes a lot of code, and makes the synchronisation mechanism even less exotic.
1 parent 84a49ae commit 1cec0ae

File tree

5 files changed

+35
-234
lines changed

5 files changed

+35
-234
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ bytes = "1"
3131
futures-core = "0.3"
3232
http = "1"
3333
http-body = "1"
34-
parking_lot = "0.12"
34+
parking_lot = { version = "0.12", features = ["arc_lock", "send_guard"] }
3535
sqlx = { version = "0.7", default-features = false }
3636
thiserror = "1"
3737
tower-layer = "0.3"

src/extension.rs

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
1+
use std::sync::Arc;
2+
3+
use parking_lot::{lock_api::ArcMutexGuard, Mutex, RawMutex};
14
use sqlx::Transaction;
25

3-
use crate::{
4-
slot::{Lease, Slot},
5-
Error, Marker, State,
6-
};
6+
use crate::{Error, Marker, State};
77

88
/// The request extension.
99
pub(crate) struct Extension<DB: Marker> {
10-
slot: Slot<LazyTransaction<DB>>,
10+
slot: Arc<Mutex<LazyTransaction<DB>>>,
1111
}
1212

1313
impl<DB: Marker> Extension<DB> {
1414
pub(crate) fn new(state: State<DB>) -> Self {
15-
let slot = Slot::new(LazyTransaction::new(state));
15+
let slot = Arc::new(Mutex::new(LazyTransaction::new(state)));
1616
Self { slot }
1717
}
1818

19-
pub(crate) async fn acquire(&self) -> Result<Lease<LazyTransaction<DB>>, Error> {
20-
let mut tx = self.slot.lease().ok_or(Error::OverlappingExtractors)?;
19+
pub(crate) async fn acquire(
20+
&self,
21+
) -> Result<ArcMutexGuard<RawMutex, LazyTransaction<DB>>, Error> {
22+
let mut tx = self
23+
.slot
24+
.try_lock_arc()
25+
.ok_or(Error::OverlappingExtractors)?;
2126
tx.acquire().await?;
2227

2328
Ok(tx)
2429
}
2530

2631
pub(crate) async fn resolve(&self) -> Result<(), sqlx::Error> {
27-
if let Some(tx) = self.slot.lease() {
28-
tx.steal().resolve().await?;
32+
if let Some(mut tx) = self.slot.try_lock_arc() {
33+
tx.resolve().await?;
2934
}
3035
Ok(())
3136
}
@@ -49,6 +54,7 @@ enum LazyTransactionState<DB: Marker> {
4954
Acquired {
5055
tx: Transaction<'static, DB::Driver>,
5156
},
57+
Resolved,
5258
}
5359

5460
impl<DB: Marker> LazyTransaction<DB> {
@@ -62,6 +68,7 @@ impl<DB: Marker> LazyTransaction<DB> {
6268
panic!("BUG: exposed unacquired LazyTransaction")
6369
}
6470
LazyTransactionState::Acquired { tx } => tx,
71+
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
6572
}
6673
}
6774

@@ -71,33 +78,36 @@ impl<DB: Marker> LazyTransaction<DB> {
7178
panic!("BUG: exposed unacquired LazyTransaction")
7279
}
7380
LazyTransactionState::Acquired { tx } => tx,
81+
LazyTransactionState::Resolved => panic!("BUG: exposed resolved LazyTransaction"),
7482
}
7583
}
7684

77-
async fn acquire(&mut self) -> Result<(), sqlx::Error> {
85+
async fn acquire(&mut self) -> Result<(), Error> {
7886
match &self.0 {
7987
LazyTransactionState::Unacquired { state } => {
8088
let tx = state.transaction().await?;
8189
self.0 = LazyTransactionState::Acquired { tx };
8290
Ok(())
8391
}
8492
LazyTransactionState::Acquired { .. } => Ok(()),
93+
LazyTransactionState::Resolved => Err(Error::OverlappingExtractors),
8594
}
8695
}
8796

88-
pub(crate) async fn resolve(self) -> Result<(), sqlx::Error> {
89-
match self.0 {
90-
LazyTransactionState::Unacquired { .. } => Ok(()),
97+
pub(crate) async fn resolve(&mut self) -> Result<(), sqlx::Error> {
98+
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
99+
LazyTransactionState::Unacquired { .. } | LazyTransactionState::Resolved => Ok(()),
91100
LazyTransactionState::Acquired { tx } => tx.commit().await,
92101
}
93102
}
94103

95-
pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
96-
match self.0 {
104+
pub(crate) async fn commit(&mut self) -> Result<(), sqlx::Error> {
105+
match std::mem::replace(&mut self.0, LazyTransactionState::Resolved) {
97106
LazyTransactionState::Unacquired { .. } => {
98107
panic!("BUG: tried to commit unacquired transaction")
99108
}
100109
LazyTransactionState::Acquired { tx } => tx.commit().await,
110+
LazyTransactionState::Resolved => panic!("BUG: tried to commit resolved transaction"),
101111
}
102112
}
103113
}

src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ mod error;
8888
mod extension;
8989
mod layer;
9090
mod marker;
91-
mod slot;
9291
mod state;
9392
mod tx;
9493

src/slot.rs

Lines changed: 0 additions & 208 deletions
This file was deleted.

src/tx.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ use axum_core::{
88
};
99
use futures_core::{future::BoxFuture, stream::BoxStream};
1010
use http::request::Parts;
11+
use parking_lot::{lock_api::ArcMutexGuard, RawMutex};
1112

1213
use crate::{
1314
extension::{Extension, LazyTransaction},
14-
slot::Lease,
1515
Config, Error, Marker, State,
1616
};
1717

@@ -74,7 +74,7 @@ use crate::{
7474
/// }
7575
/// ```
7676
pub struct Tx<DB: Marker, E = Error> {
77-
tx: Lease<LazyTransaction<DB>>,
77+
tx: ArcMutexGuard<RawMutex, LazyTransaction<DB>>,
7878
_error: PhantomData<E>,
7979
}
8080

@@ -121,8 +121,8 @@ impl<DB: Marker, E> Tx<DB, E> {
121121
///
122122
/// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently
123123
/// generate [`Error::OverlappingExtractors`] errors. This may change in future.
124-
pub async fn commit(self) -> Result<(), sqlx::Error> {
125-
self.tx.steal().commit().await
124+
pub async fn commit(mut self) -> Result<(), sqlx::Error> {
125+
self.tx.commit().await
126126
}
127127
}
128128

@@ -134,27 +134,27 @@ impl<DB: Marker, E> fmt::Debug for Tx<DB, E> {
134134

135135
impl<DB: Marker, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
136136
fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
137-
self.tx.as_ref().as_ref()
137+
self.tx.as_ref()
138138
}
139139
}
140140

141141
impl<DB: Marker, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
142142
fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
143-
self.tx.as_mut().as_mut()
143+
self.tx.as_mut()
144144
}
145145
}
146146

147147
impl<DB: Marker, E> std::ops::Deref for Tx<DB, E> {
148148
type Target = sqlx::Transaction<'static, DB::Driver>;
149149

150150
fn deref(&self) -> &Self::Target {
151-
self.tx.as_ref().as_ref()
151+
self.tx.as_ref()
152152
}
153153
}
154154

155155
impl<DB: Marker, E> std::ops::DerefMut for Tx<DB, E> {
156156
fn deref_mut(&mut self) -> &mut Self::Target {
157-
self.tx.as_mut().as_mut()
157+
self.tx.as_mut()
158158
}
159159
}
160160

0 commit comments

Comments
 (0)