Skip to content

Commit 1940658

Browse files
authored
Merge pull request #27 from digital-society-coop/state
Simplify API and use "state" to improve type-safety
2 parents fcba66c + e896b5a commit 1940658

File tree

6 files changed

+332
-130
lines changed

6 files changed

+332
-130
lines changed

examples/example.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use std::error::Error;
55
use axum::{response::IntoResponse, routing::get, Json};
66
use http::StatusCode;
77

8-
// OPTIONAL: use a type alias to avoid repeating your database type
8+
// Recommended: use a type alias to avoid repeating your database type
99
type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
1010

1111
#[tokio::main]
@@ -19,11 +19,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
1919
.execute(&pool)
2020
.await?;
2121

22+
let (state, layer) = Tx::setup(pool);
23+
2224
// Standard axum app setup
2325
let app = axum::Router::new()
2426
.route("/numbers", get(list_numbers).post(generate_number))
2527
// Apply the Tx middleware
26-
.layer(axum_sqlx_tx::Layer::new(pool.clone()));
28+
.layer(layer)
29+
// Add the Tx state
30+
.with_state(state);
2731

2832
let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service());
2933

src/layer.rs

Lines changed: 31 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use bytes::Bytes;
77
use futures_core::future::BoxFuture;
88
use http_body::{combinators::UnsyncBoxBody, Body};
99

10-
use crate::{tx::TxSlot, Error};
10+
use crate::{tx::TxSlot, State};
1111

1212
/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
1313
///
@@ -20,34 +20,19 @@ use crate::{tx::TxSlot, Error};
2020
///
2121
/// [`Tx`]: crate::Tx
2222
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
23-
pub struct Layer<DB: sqlx::Database, E = Error> {
24-
pool: sqlx::Pool<DB>,
23+
pub struct Layer<DB: sqlx::Database, E> {
24+
state: State<DB>,
2525
_error: PhantomData<E>,
2626
}
2727

28-
impl<DB: sqlx::Database> Layer<DB> {
29-
/// Construct a new layer with the given `pool`.
30-
///
31-
/// A connection will be obtained from the pool the first time a [`Tx`](crate::Tx) is extracted
32-
/// from a request.
33-
///
34-
/// If you want to access the pool outside of a transaction, you should add it also with
35-
/// [`axum::Extension`].
36-
///
37-
/// To use a different type than [`Error`] to convert commit errors into responses, see
38-
/// [`new_with_error`](Self::new_with_error).
39-
///
40-
/// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html
41-
pub fn new(pool: sqlx::Pool<DB>) -> Self {
42-
Self::new_with_error(pool)
43-
}
44-
45-
/// Construct a new layer with a specific error type.
46-
///
47-
/// See [`Layer::new`] for more information.
48-
pub fn new_with_error<E>(pool: sqlx::Pool<DB>) -> Layer<DB, E> {
49-
Layer {
50-
pool,
28+
impl<DB: sqlx::Database, E> Layer<DB, E>
29+
where
30+
E: IntoResponse,
31+
sqlx::Error: Into<E>,
32+
{
33+
pub(crate) fn new(state: State<DB>) -> Self {
34+
Self {
35+
state,
5136
_error: PhantomData,
5237
}
5338
}
@@ -56,18 +41,22 @@ impl<DB: sqlx::Database> Layer<DB> {
5641
impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
5742
fn clone(&self) -> Self {
5843
Self {
59-
pool: self.pool.clone(),
44+
state: self.state.clone(),
6045
_error: self._error,
6146
}
6247
}
6348
}
6449

65-
impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
50+
impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E>
51+
where
52+
E: IntoResponse,
53+
sqlx::Error: Into<E>,
54+
{
6655
type Service = Service<DB, S, E>;
6756

6857
fn layer(&self, inner: S) -> Self::Service {
6958
Service {
70-
pool: self.pool.clone(),
59+
state: self.state.clone(),
7160
inner,
7261
_error: self._error,
7362
}
@@ -77,8 +66,8 @@ impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
7766
/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
7867
///
7968
/// See [`Layer`] for more information.
80-
pub struct Service<DB: sqlx::Database, S, E = Error> {
81-
pool: sqlx::Pool<DB>,
69+
pub struct Service<DB: sqlx::Database, S, E> {
70+
state: State<DB>,
8271
inner: S,
8372
_error: PhantomData<E>,
8473
}
@@ -87,7 +76,7 @@ pub struct Service<DB: sqlx::Database, S, E = Error> {
8776
impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
8877
fn clone(&self) -> Self {
8978
Self {
90-
pool: self.pool.clone(),
79+
state: self.state.clone(),
9180
inner: self.inner.clone(),
9281
_error: self._error,
9382
}
@@ -103,7 +92,8 @@ where
10392
Error = std::convert::Infallible,
10493
>,
10594
S::Future: Send + 'static,
106-
E: From<Error> + IntoResponse,
95+
E: IntoResponse,
96+
sqlx::Error: Into<E>,
10797
ResBody: Body<Data = Bytes> + Send + 'static,
10898
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
10999
{
@@ -119,7 +109,7 @@ where
119109
}
120110

121111
fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
122-
let transaction = TxSlot::bind(req.extensions_mut(), self.pool.clone());
112+
let transaction = TxSlot::bind(req.extensions_mut(), self.state.clone());
123113

124114
let res = self.inner.call(req);
125115

@@ -128,7 +118,7 @@ where
128118

129119
if !res.status().is_server_error() && !res.status().is_client_error() {
130120
if let Err(error) = transaction.commit().await {
131-
return Ok(E::from(Error::Database { error }).into_response());
121+
return Ok(error.into().into_response());
132122
}
133123
}
134124

@@ -139,17 +129,21 @@ where
139129

140130
#[cfg(test)]
141131
mod tests {
132+
use crate::{Error, State};
133+
142134
use super::Layer;
143135

144136
// The trait shenanigans required by axum for layers are significant, so this "test" ensures
145137
// we've got it right.
146138
#[allow(unused, unreachable_code, clippy::diverging_sub_expression)]
147139
fn layer_compiles() {
148-
let pool: sqlx::Pool<sqlx::Sqlite> = todo!();
140+
let state: State<sqlx::Sqlite> = todo!();
141+
142+
let layer = Layer::<_, Error>::new(state);
149143

150144
let app = axum::Router::new()
151145
.route("/", axum::routing::get(|| async { "hello" }))
152-
.layer(Layer::new(pool));
146+
.layer(layer);
153147

154148
axum::Server::bind(todo!()).serve(app.into_make_service());
155149
}

0 commit comments

Comments
 (0)