Skip to content

Commit 0325d41

Browse files
authored
Merge pull request #19 from digital-society-coop/multi-db
Add support for multiple database instances
2 parents 5957c32 + 243fc03 commit 0325d41

File tree

8 files changed

+405
-214
lines changed

8 files changed

+405
-214
lines changed

src/config.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use std::marker::PhantomData;
2+
3+
use crate::{Layer, Marker, State};
4+
5+
/// Configuration for [`Tx`](crate::Tx) extractors.
6+
///
7+
/// Use `Config` to configure and create a [`State`] and [`Layer`].
8+
///
9+
/// Access the `Config` API from [`Tx::config`](crate::Tx::config).
10+
///
11+
/// ```
12+
/// # async fn foo() {
13+
/// # let pool: sqlx::SqlitePool = todo!();
14+
/// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
15+
///
16+
/// let config = Tx::config(pool);
17+
/// # }
18+
/// ```
19+
pub struct Config<DB: Marker, LayerError> {
20+
pool: sqlx::Pool<DB::Driver>,
21+
_layer_error: PhantomData<LayerError>,
22+
}
23+
24+
impl<DB: Marker, LayerError> Config<DB, LayerError>
25+
where
26+
LayerError: axum_core::response::IntoResponse,
27+
sqlx::Error: Into<LayerError>,
28+
{
29+
pub(crate) fn new(pool: sqlx::Pool<DB::Driver>) -> Self {
30+
Self {
31+
pool,
32+
_layer_error: PhantomData,
33+
}
34+
}
35+
36+
/// Change the layer error type.
37+
pub fn layer_error<E>(self) -> Config<DB, E>
38+
where
39+
sqlx::Error: Into<E>,
40+
{
41+
Config {
42+
pool: self.pool,
43+
_layer_error: PhantomData,
44+
}
45+
}
46+
47+
/// Create a [`State`] and [`Layer`] to enable the [`Tx`](crate::Tx) extractor.
48+
pub fn setup(self) -> (State<DB>, Layer<DB, LayerError>) {
49+
let state = State::new(self.pool);
50+
let layer = Layer::new(state.clone());
51+
(state, layer)
52+
}
53+
}

src/error.rs

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/// Possible errors when extracting [`Tx`] from a request.
2+
///
3+
/// Errors can occur at two points during the request lifecycle:
4+
///
5+
/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a
6+
/// transaction. This could be due to:
7+
///
8+
/// - Forgetting to add the middleware: [`Error::MissingExtension`].
9+
/// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`].
10+
/// - A problem communicating with the database: [`Error::Database`].
11+
///
12+
/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem
13+
/// communicating with the database, or else a logic error (e.g. unsatisfied deferred
14+
/// constraint): [`Error::Database`].
15+
///
16+
/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a
17+
/// HTTP 500 response with the error message as the response body. This may be suitable for
18+
/// development or internal services but it's generally not advisable to return internal error
19+
/// details to clients.
20+
///
21+
/// You can override the error types for both the [`Tx`] extractor and [`Layer`]:
22+
///
23+
/// - Override the [`Tx`]`<DB, E>` error type using the `E` generic type parameter. `E` must be
24+
/// convertible from [`Error`] (e.g. [`Error`]`: Into<E>`).
25+
///
26+
/// - Override the [`Layer`] error type using [`Config::layer_error`](crate::Config::layer_error).
27+
/// The layer error type must be convertible from `sqlx::Error` (e.g.
28+
/// `sqlx::Error: Into<LayerError>`).
29+
///
30+
/// In both cases, the error type must implement `axum::response::IntoResponse`.
31+
///
32+
/// ```
33+
/// use axum::{response::IntoResponse, routing::post};
34+
///
35+
/// enum MyError{
36+
/// Extractor(axum_sqlx_tx::Error),
37+
/// Layer(sqlx::Error),
38+
/// }
39+
///
40+
/// impl From<axum_sqlx_tx::Error> for MyError {
41+
/// fn from(error: axum_sqlx_tx::Error) -> Self {
42+
/// Self::Extractor(error)
43+
/// }
44+
/// }
45+
///
46+
/// impl From<sqlx::Error> for MyError {
47+
/// fn from(error: sqlx::Error) -> Self {
48+
/// Self::Layer(error)
49+
/// }
50+
/// }
51+
///
52+
/// impl IntoResponse for MyError {
53+
/// fn into_response(self) -> axum::response::Response {
54+
/// // note that you would probably want to log the error as well
55+
/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
56+
/// }
57+
/// }
58+
///
59+
/// // Override the `Tx` error type using the second generic type parameter
60+
/// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite, MyError>;
61+
///
62+
/// # async fn foo() {
63+
/// let pool = sqlx::SqlitePool::connect("...").await.unwrap();
64+
///
65+
/// let (state, layer) = Tx::config(pool)
66+
/// // Override the `Layer` error type using the `Config` API
67+
/// .layer_error::<MyError>()
68+
/// .setup();
69+
/// # let app = axum::Router::new()
70+
/// # .route("/", post(create_user))
71+
/// # .layer(layer)
72+
/// # .with_state(state);
73+
/// # axum::Server::bind(todo!()).serve(app.into_make_service());
74+
/// # }
75+
/// # async fn create_user(mut tx: Tx, /* ... */) {
76+
/// # /* ... */
77+
/// # }
78+
/// ```
79+
///
80+
/// [`Tx`]: crate::Tx
81+
/// [`Layer`]: crate::Layer
82+
#[derive(Debug, thiserror::Error)]
83+
pub enum Error {
84+
/// Indicates that the [`Layer`](crate::Layer) middleware was not installed.
85+
#[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")]
86+
MissingExtension,
87+
88+
/// Indicates that [`Tx`](crate::Tx) was extracted multiple times in a single
89+
/// handler/middleware.
90+
#[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")]
91+
OverlappingExtractors,
92+
93+
/// A database error occurred when starting or committing the transaction.
94+
#[error(transparent)]
95+
Database {
96+
#[from]
97+
error: sqlx::Error,
98+
},
99+
}
100+
101+
impl axum_core::response::IntoResponse for Error {
102+
fn into_response(self) -> axum_core::response::Response {
103+
(http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response()
104+
}
105+
}

src/layer.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use bytes::Bytes;
77
use futures_core::future::BoxFuture;
88
use http_body::{combinators::UnsyncBoxBody, Body};
99

10-
use crate::{tx::TxSlot, State};
10+
use crate::{tx::TxSlot, Marker, State};
1111

1212
/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
1313
///
@@ -20,12 +20,12 @@ use crate::{tx::TxSlot, State};
2020
///
2121
/// [`Tx`]: crate::Tx
2222
/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
23-
pub struct Layer<DB: sqlx::Database, E> {
23+
pub struct Layer<DB: Marker, E> {
2424
state: State<DB>,
2525
_error: PhantomData<E>,
2626
}
2727

28-
impl<DB: sqlx::Database, E> Layer<DB, E>
28+
impl<DB: Marker, E> Layer<DB, E>
2929
where
3030
E: IntoResponse,
3131
sqlx::Error: Into<E>,
@@ -38,7 +38,7 @@ where
3838
}
3939
}
4040

41-
impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
41+
impl<DB: Marker, E> Clone for Layer<DB, E> {
4242
fn clone(&self) -> Self {
4343
Self {
4444
state: self.state.clone(),
@@ -47,7 +47,7 @@ impl<DB: sqlx::Database, E> Clone for Layer<DB, E> {
4747
}
4848
}
4949

50-
impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E>
50+
impl<DB: Marker, S, E> tower_layer::Layer<S> for Layer<DB, E>
5151
where
5252
E: IntoResponse,
5353
sqlx::Error: Into<E>,
@@ -66,14 +66,14 @@ where
6666
/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
6767
///
6868
/// See [`Layer`] for more information.
69-
pub struct Service<DB: sqlx::Database, S, E> {
69+
pub struct Service<DB: Marker, S, E> {
7070
state: State<DB>,
7171
inner: S,
7272
_error: PhantomData<E>,
7373
}
7474

7575
// can't simply derive because `DB` isn't `Clone`
76-
impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
76+
impl<DB: Marker, S: Clone, E> Clone for Service<DB, S, E> {
7777
fn clone(&self) -> Self {
7878
Self {
7979
state: self.state.clone(),
@@ -83,7 +83,7 @@ impl<DB: sqlx::Database, S: Clone, E> Clone for Service<DB, S, E> {
8383
}
8484
}
8585

86-
impl<DB: sqlx::Database, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
86+
impl<DB: Marker, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
8787
for Service<DB, S, E>
8888
where
8989
S: tower_service::Service<

0 commit comments

Comments
 (0)