Skip to content

Commit 9d02291

Browse files
author
Chris Connelly
authored
Merge pull request #8 from wasdacraic/error-handling
Error handling
2 parents e3f6f3f + 0ba85c9 commit 9d02291

File tree

6 files changed

+291
-123
lines changed

6 files changed

+291
-123
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22
name = "axum-sqlx-tx"
33
description = "Request-scoped SQLx transactions for axum"
4-
version = "0.1.2"
4+
version = "0.2.0"
55
license = "MIT"
66
repository = "https://github.com/wasdacraic/axum-sqlx-tx/"
77
edition = "2021"
@@ -28,8 +28,10 @@ features = ["all-databases", "runtime-tokio-rustls"]
2828

2929
[dependencies]
3030
axum-core = "0.1.2"
31+
bytes = "1.1.0"
3132
futures-core = "0.3.21"
3233
http = "0.2.6"
34+
http-body = "0.4.4"
3335
parking_lot = "0.12.0"
3436
sqlx = { version = "0.5.11", default-features = false }
3537
thiserror = "1.0.30"

examples/example.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
33
use std::error::Error;
44

5-
use axum::{error_handling::HandleErrorLayer, response::IntoResponse, routing::get, Json};
5+
use axum::{response::IntoResponse, routing::get, Json};
66
use http::StatusCode;
77

88
// OPTIONAL: use a type alias to avoid repeating your database type
@@ -22,16 +22,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
2222
// Standard axum app setup
2323
let app = axum::Router::new()
2424
.route("/numbers", get(list_numbers).post(generate_number))
25-
.layer(
26-
// Apply the Tx middleware
27-
tower::ServiceBuilder::new()
28-
.layer(HandleErrorLayer::new(|error: sqlx::Error| async move {
29-
// The transaction is committed by the middleware so an error is possible
30-
// and must be converted into a response.
31-
DbError(error)
32-
}))
33-
.layer(axum_sqlx_tx::Layer::new(pool.clone())),
34-
);
25+
// Apply the Tx middleware
26+
.layer(axum_sqlx_tx::Layer::new(pool.clone()));
3527

3628
let server = axum::Server::bind(&([0, 0, 0, 0], 0).into()).serve(app.into_make_service());
3729

src/layer.rs

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
//! A [`tower_layer::Layer`] that enables the [`Tx`](crate::Tx) extractor.
22
3+
use std::marker::PhantomData;
4+
5+
use axum_core::response::IntoResponse;
6+
use bytes::Bytes;
37
use futures_core::future::BoxFuture;
8+
use http_body::{combinators::UnsyncBoxBody, Body};
49

5-
use crate::tx::TxSlot;
10+
use crate::{tx::TxSlot, Error};
611

712
/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
813
///
@@ -15,8 +20,9 @@ use crate::tx::TxSlot;
1520
///
1621
/// [`Tx`]: crate::Tx
1722
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
18-
pub struct Layer<DB: sqlx::Database> {
23+
pub struct Layer<DB: sqlx::Database, E = Error> {
1924
pool: sqlx::Pool<DB>,
25+
_error: PhantomData<E>,
2026
}
2127

2228
impl<DB: sqlx::Database> Layer<DB> {
@@ -28,54 +34,72 @@ impl<DB: sqlx::Database> Layer<DB> {
2834
/// If you want to access the pool outside of a transaction, you should add it also with
2935
/// [`axum::Extension`].
3036
///
37+
/// To use a different type than [`Error`] to convert commit errors into responses, see
38+
/// [`new_with_error`](Self::new_with_error).
39+
///
3140
/// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html
3241
pub fn new(pool: sqlx::Pool<DB>) -> Self {
33-
Self { pool }
42+
Self::new_with_error(pool)
43+
}
44+
45+
/// Construct a new layer with a specific error type.
46+
///
47+
/// See [`Layer::new`] for more information.
48+
pub fn new_with_error<E>(pool: sqlx::Pool<DB>) -> Layer<DB, E> {
49+
Layer {
50+
pool,
51+
_error: PhantomData,
52+
}
3453
}
3554
}
3655

37-
impl<DB: sqlx::Database, S> tower_layer::Layer<S> for Layer<DB> {
38-
type Service = Service<DB, S>;
56+
impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
57+
type Service = Service<DB, S, E>;
3958

4059
fn layer(&self, inner: S) -> Self::Service {
4160
Service {
4261
pool: self.pool.clone(),
4362
inner,
63+
_error: self._error,
4464
}
4565
}
4666
}
4767

4868
/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
4969
///
5070
/// See [`Layer`] for more information.
51-
pub struct Service<DB: sqlx::Database, S> {
71+
pub struct Service<DB: sqlx::Database, S, E = Error> {
5272
pool: sqlx::Pool<DB>,
5373
inner: S,
74+
_error: PhantomData<E>,
5475
}
5576

5677
// can't simply derive because `DB` isn't `Clone`
57-
impl<DB: sqlx::Database, S: Clone> Clone for Service<DB, S> {
78+
impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
5879
fn clone(&self) -> Self {
5980
Self {
6081
pool: self.pool.clone(),
6182
inner: self.inner.clone(),
83+
_error: self._error,
6284
}
6385
}
6486
}
6587

66-
impl<DB: sqlx::Database, S, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
67-
for Service<DB, S>
88+
impl<DB: sqlx::Database, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
89+
for Service<DB, S, E>
6890
where
6991
S: tower_service::Service<
7092
http::Request<ReqBody>,
7193
Response = http::Response<ResBody>,
7294
Error = std::convert::Infallible,
7395
>,
7496
S::Future: Send + 'static,
75-
ResBody: Send,
97+
E: From<Error> + IntoResponse,
98+
ResBody: Body<Data = Bytes> + Send + 'static,
99+
ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
76100
{
77-
type Response = S::Response;
78-
type Error = sqlx::Error;
101+
type Response = http::Response<UnsyncBoxBody<ResBody::Data, axum_core::Error>>;
102+
type Error = S::Error;
79103
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
80104

81105
fn poll_ready(
@@ -94,19 +118,18 @@ where
94118
let res = res.await.unwrap(); // inner service is infallible
95119

96120
if res.status().is_success() {
97-
transaction.commit().await?;
121+
if let Err(error) = transaction.commit().await {
122+
return Ok(E::from(Error::Database { error }).into_response());
123+
}
98124
}
99125

100-
Ok(res)
126+
Ok(res.map(|body| body.map_err(axum_core::Error::new).boxed_unsync()))
101127
})
102128
}
103129
}
104130

105131
#[cfg(test)]
106132
mod tests {
107-
use axum::error_handling::HandleErrorLayer;
108-
use tower::ServiceBuilder;
109-
110133
use super::Layer;
111134

112135
// The trait shenanigans required by axum for layers are significant, so this "test" ensures
@@ -117,11 +140,7 @@ mod tests {
117140

118141
let app = axum::Router::new()
119142
.route("/", axum::routing::get(|| async { "hello" }))
120-
.layer(
121-
ServiceBuilder::new()
122-
.layer(HandleErrorLayer::new(|_: sqlx::Error| async {}))
123-
.layer(Layer::new(pool)),
124-
);
143+
.layer(Layer::new(pool));
125144

126145
axum::Server::bind(todo!()).serve(app.into_make_service());
127146
}

src/lib.rs

Lines changed: 110 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,30 +19,20 @@
1919
//!
2020
//! To use the [`Tx`] extractor, you must first add [`Layer`] to your app:
2121
//!
22-
//! ```no_run
23-
//! # use axum::error_handling::HandleErrorLayer;
22+
//! ```
2423
//! # async fn foo() {
2524
//! let pool = /* any sqlx::Pool */
2625
//! # sqlx::SqlitePool::connect(todo!()).await.unwrap();
2726
//! let app = axum::Router::new()
2827
//! // .route(...)s
29-
//! .layer(
30-
//! tower::ServiceBuilder::new()
31-
//! // The transaction is committed by the middleware so an error is possible and must
32-
//! // be converted into a response
33-
//! .layer(HandleErrorLayer::new(|error: sqlx::Error| async move {
34-
//! http::StatusCode::INTERNAL_SERVER_ERROR
35-
//! }))
36-
//! // Now we can add the middleware
37-
//! .layer(axum_sqlx_tx::Layer::new(pool)),
38-
//! );
28+
//! .layer(axum_sqlx_tx::Layer::new(pool));
3929
//! # axum::Server::bind(todo!()).serve(app.into_make_service());
4030
//! # }
4131
//! ```
4232
//!
4333
//! You can then simply add [`Tx`] as an argument to your handlers:
4434
//!
45-
//! ```no_run
35+
//! ```
4636
//! use axum_sqlx_tx::Tx;
4737
//! use sqlx::Sqlite;
4838
//!
@@ -65,6 +55,56 @@
6555
//! you have multiple `Tx` arguments in a single handler, or call `Tx::from_request` multiple times
6656
//! in a single middleware.
6757
//!
58+
//! ## Error handling
59+
//!
60+
//! `axum` requires that middleware do not return errors, and that the errors returned by extractors
61+
//! implement `IntoResponse`. By default, [`Error`](Error) is used by [`Layer`] and [`Tx`] to
62+
//! convert errors into HTTP 500 responses, with the error's `Display` value as the response body,
63+
//! however it's generally not a good practice to return internal error details to clients!
64+
//!
65+
//! To make it easier to customise error handling, both [`Layer`] and [`Tx`] have a second generic
66+
//! type parameter, `E`, that can be used to override the error type that will be used to convert
67+
//! the response.
68+
//!
69+
//! ```
70+
//! use axum::response::IntoResponse;
71+
//! use axum_sqlx_tx::Tx;
72+
//! use sqlx::Sqlite;
73+
//!
74+
//! struct MyError(axum_sqlx_tx::Error);
75+
//!
76+
//! // Errors must implement From<axum_sqlx_tx::Error>
77+
//! impl From<axum_sqlx_tx::Error> for MyError {
78+
//! fn from(error: axum_sqlx_tx::Error) -> Self {
79+
//! Self(error)
80+
//! }
81+
//! }
82+
//!
83+
//! // Errors must implement IntoResponse
84+
//! impl IntoResponse for MyError {
85+
//! fn into_response(self) -> axum::response::Response {
86+
//! // note that you would probably want to log the error or something
87+
//! (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
88+
//! }
89+
//! }
90+
//!
91+
//! // Change the layer error type
92+
//! # async fn foo() {
93+
//! # let pool: sqlx::SqlitePool = todo!();
94+
//! let app = axum::Router::new()
95+
//! // .route(...)s
96+
//! .layer(axum_sqlx_tx::Layer::new_with_error::<MyError>(pool));
97+
//! # axum::Server::bind(todo!()).serve(app.into_make_service());
98+
//! # }
99+
//!
100+
//! // Change the extractor error type
101+
//! async fn create_user(mut tx: Tx<Sqlite, MyError>, /* ... */) {
102+
//! /* ... */
103+
//! }
104+
//! ```
105+
//!
106+
//! # Examples
107+
//!
68108
//! See [`examples/`][examples] in the repo for more examples.
69109
//!
70110
//! [examples]: https://github.com/wasdacraic/axum-sqlx-tx/tree/master/examples
@@ -77,5 +117,61 @@ mod tx;
77117

78118
pub use crate::{
79119
layer::{Layer, Service},
80-
tx::{Error, Tx},
120+
tx::Tx,
81121
};
122+
123+
/// Possible errors when extracting [`Tx`] from a request.
124+
///
125+
/// `axum` requires that the `FromRequest` `Rejection` implements `IntoResponse`, which this does
126+
/// by returning the `Display` representation of the variant. Note that this means returning
127+
/// configuration and database errors to clients, but you can override the type of error that
128+
/// `Tx::from_request` returns using the `E` generic parameter:
129+
///
130+
/// ```
131+
/// use axum::response::IntoResponse;
132+
/// use axum_sqlx_tx::Tx;
133+
/// use sqlx::Sqlite;
134+
///
135+
/// struct MyError(axum_sqlx_tx::Error);
136+
///
137+
/// // The error type must implement From<axum_sqlx_tx::Error>
138+
/// impl From<axum_sqlx_tx::Error> for MyError {
139+
/// fn from(error: axum_sqlx_tx::Error) -> Self {
140+
/// Self(error)
141+
/// }
142+
/// }
143+
///
144+
/// // The error type must implement IntoResponse
145+
/// impl IntoResponse for MyError {
146+
/// fn into_response(self) -> axum::response::Response {
147+
/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
148+
/// }
149+
/// }
150+
///
151+
/// async fn handler(tx: Tx<Sqlite, MyError>) {
152+
/// /* ... */
153+
/// }
154+
/// ```
155+
#[derive(Debug, thiserror::Error)]
156+
pub enum Error {
157+
/// Indicates that the [`Layer`](crate::Layer) middleware was not installed.
158+
#[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")]
159+
MissingExtension,
160+
161+
/// Indicates that [`Tx`] was extracted multiple times in a single handler/middleware.
162+
#[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")]
163+
OverlappingExtractors,
164+
165+
/// A database error occurred when starting the transaction.
166+
#[error(transparent)]
167+
Database {
168+
#[from]
169+
error: sqlx::Error,
170+
},
171+
}
172+
173+
impl axum_core::response::IntoResponse for Error {
174+
fn into_response(self) -> axum_core::response::Response {
175+
(http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
176+
}
177+
}

0 commit comments

Comments
 (0)