diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..e791ec1 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,53 @@ +use std::marker::PhantomData; + +use crate::{Layer, Marker, State}; + +/// Configuration for [`Tx`](crate::Tx) extractors. +/// +/// Use `Config` to configure and create a [`State`] and [`Layer`]. +/// +/// Access the `Config` API from [`Tx::config`](crate::Tx::config). +/// +/// ``` +/// # async fn foo() { +/// # let pool: sqlx::SqlitePool = todo!(); +/// type Tx = axum_sqlx_tx::Tx; +/// +/// let config = Tx::config(pool); +/// # } +/// ``` +pub struct Config { + pool: sqlx::Pool, + _layer_error: PhantomData, +} + +impl Config +where + LayerError: axum_core::response::IntoResponse, + sqlx::Error: Into, +{ + pub(crate) fn new(pool: sqlx::Pool) -> Self { + Self { + pool, + _layer_error: PhantomData, + } + } + + /// Change the layer error type. + pub fn layer_error(self) -> Config + where + sqlx::Error: Into, + { + Config { + pool: self.pool, + _layer_error: PhantomData, + } + } + + /// Create a [`State`] and [`Layer`] to enable the [`Tx`](crate::Tx) extractor. + pub fn setup(self) -> (State, Layer) { + let state = State::new(self.pool); + let layer = Layer::new(state.clone()); + (state, layer) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..014ea2c --- /dev/null +++ b/src/error.rs @@ -0,0 +1,105 @@ +/// Possible errors when extracting [`Tx`] from a request. +/// +/// Errors can occur at two points during the request lifecycle: +/// +/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a +/// transaction. This could be due to: +/// +/// - Forgetting to add the middleware: [`Error::MissingExtension`]. +/// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`]. +/// - A problem communicating with the database: [`Error::Database`]. +/// +/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem +/// communicating with the database, or else a logic error (e.g. unsatisfied deferred +/// constraint): [`Error::Database`]. +/// +/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a +/// HTTP 500 response with the error message as the response body. This may be suitable for +/// development or internal services but it's generally not advisable to return internal error +/// details to clients. +/// +/// You can override the error types for both the [`Tx`] extractor and [`Layer`]: +/// +/// - Override the [`Tx`]`` error type using the `E` generic type parameter. `E` must be +/// convertible from [`Error`] (e.g. [`Error`]`: Into`). +/// +/// - Override the [`Layer`] error type using [`Config::layer_error`](crate::Config::layer_error). +/// The layer error type must be convertible from `sqlx::Error` (e.g. +/// `sqlx::Error: Into`). +/// +/// In both cases, the error type must implement `axum::response::IntoResponse`. +/// +/// ``` +/// use axum::{response::IntoResponse, routing::post}; +/// +/// enum MyError{ +/// Extractor(axum_sqlx_tx::Error), +/// Layer(sqlx::Error), +/// } +/// +/// impl From for MyError { +/// fn from(error: axum_sqlx_tx::Error) -> Self { +/// Self::Extractor(error) +/// } +/// } +/// +/// impl From for MyError { +/// fn from(error: sqlx::Error) -> Self { +/// Self::Layer(error) +/// } +/// } +/// +/// impl IntoResponse for MyError { +/// fn into_response(self) -> axum::response::Response { +/// // note that you would probably want to log the error as well +/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() +/// } +/// } +/// +/// // Override the `Tx` error type using the second generic type parameter +/// type Tx = axum_sqlx_tx::Tx; +/// +/// # async fn foo() { +/// let pool = sqlx::SqlitePool::connect("...").await.unwrap(); +/// +/// let (state, layer) = Tx::config(pool) +/// // Override the `Layer` error type using the `Config` API +/// .layer_error::() +/// .setup(); +/// # let app = axum::Router::new() +/// # .route("/", post(create_user)) +/// # .layer(layer) +/// # .with_state(state); +/// # axum::Server::bind(todo!()).serve(app.into_make_service()); +/// # } +/// # async fn create_user(mut tx: Tx, /* ... */) { +/// # /* ... */ +/// # } +/// ``` +/// +/// [`Tx`]: crate::Tx +/// [`Layer`]: crate::Layer +#[derive(Debug, thiserror::Error)] +pub enum Error { + /// Indicates that the [`Layer`](crate::Layer) middleware was not installed. + #[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")] + MissingExtension, + + /// Indicates that [`Tx`](crate::Tx) was extracted multiple times in a single + /// handler/middleware. + #[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")] + OverlappingExtractors, + + /// A database error occurred when starting or committing the transaction. + #[error(transparent)] + Database { + #[from] + error: sqlx::Error, + }, +} + +impl axum_core::response::IntoResponse for Error { + fn into_response(self) -> axum_core::response::Response { + (http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() + } +} diff --git a/src/layer.rs b/src/layer.rs index b0ebf6d..b5b4527 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -7,7 +7,7 @@ use bytes::Bytes; use futures_core::future::BoxFuture; use http_body::{combinators::UnsyncBoxBody, Body}; -use crate::{tx::TxSlot, State}; +use crate::{tx::TxSlot, Marker, State}; /// A [`tower_layer::Layer`] that enables the [`Tx`] extractor. /// @@ -20,12 +20,12 @@ use crate::{tx::TxSlot, State}; /// /// [`Tx`]: crate::Tx /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html -pub struct Layer { +pub struct Layer { state: State, _error: PhantomData, } -impl Layer +impl Layer where E: IntoResponse, sqlx::Error: Into, @@ -38,7 +38,7 @@ where } } -impl Clone for Layer { +impl Clone for Layer { fn clone(&self) -> Self { Self { state: self.state.clone(), @@ -47,7 +47,7 @@ impl Clone for Layer { } } -impl tower_layer::Layer for Layer +impl tower_layer::Layer for Layer where E: IntoResponse, sqlx::Error: Into, @@ -66,14 +66,14 @@ where /// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor. /// /// See [`Layer`] for more information. -pub struct Service { +pub struct Service { state: State, inner: S, _error: PhantomData, } // can't simply derive because `DB` isn't `Clone` -impl Clone for Service { +impl Clone for Service { fn clone(&self) -> Self { Self { state: self.state.clone(), @@ -83,7 +83,7 @@ impl Clone for Service { } } -impl tower_service::Service> +impl tower_service::Service> for Service where S: tower_service::Service< diff --git a/src/lib.rs b/src/lib.rs index e7f3545..49607b9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -67,6 +67,14 @@ //! //! See [`Error`] for how to customise error handling. //! +//! ## Multiple databases +//! +//! If you need to work with multiple databases, you can define marker structs for each. See +//! [`Marker`] for an example. +//! +//! It's not currently possible to use `Tx` for a dynamic number of databases. Feel free to open an +//! issue if you have a requirement for this. +//! //! # Examples //! //! See [`examples/`][examples] in the repo for more examples. @@ -75,194 +83,19 @@ #![cfg_attr(doc, deny(warnings))] +mod config; +mod error; mod layer; +mod marker; mod slot; +mod state; mod tx; -use std::marker::PhantomData; - pub use crate::{ + config::Config, + error::Error, layer::{Layer, Service}, + marker::Marker, + state::State, tx::Tx, }; - -/// Configuration for [`Tx`] extractors. -/// -/// Use `Config` to configure and create a [`State`] and [`Layer`]. -/// -/// Access the `Config` API from [`Tx::config`]. -/// -/// ``` -/// # async fn foo() { -/// # let pool: sqlx::SqlitePool = todo!(); -/// type Tx = axum_sqlx_tx::Tx; -/// -/// let config = Tx::config(pool); -/// # } -/// ``` -pub struct Config { - pool: sqlx::Pool, - _layer_error: PhantomData, -} - -impl Config -where - LayerError: axum_core::response::IntoResponse, - sqlx::Error: Into, -{ - fn new(pool: sqlx::Pool) -> Self { - Self { - pool, - _layer_error: PhantomData, - } - } - - /// Change the layer error type. - pub fn layer_error(self) -> Config - where - sqlx::Error: Into, - { - Config { - pool: self.pool, - _layer_error: PhantomData, - } - } - - /// Create a [`State`] and [`Layer`] to enable the [`Tx`] extractor. - pub fn setup(self) -> (State, Layer) { - let state = State::new(self.pool); - let layer = Layer::new(state.clone()); - (state, layer) - } -} - -/// Application state that enables the [`Tx`] extractor. -/// -/// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting -/// to use the `Router` will not compile. -/// -/// `State` is constructed via [`Tx::setup`] or [`Config::setup`], which also return a middleware -/// [`Layer`]. The state and the middleware together enable the [`Tx`] extractor to work. -#[derive(Debug)] -pub struct State { - pool: sqlx::Pool, -} - -impl State { - pub(crate) fn new(pool: sqlx::Pool) -> Self { - Self { pool } - } - - pub(crate) async fn transaction(&self) -> Result, sqlx::Error> { - self.pool.begin().await - } -} - -impl Clone for State { - fn clone(&self) -> Self { - Self { - pool: self.pool.clone(), - } - } -} - -/// Possible errors when extracting [`Tx`] from a request. -/// -/// Errors can occur at two points during the request lifecycle: -/// -/// 1. The [`Tx`] extractor might fail to obtain a connection from the pool and `BEGIN` a -/// transaction. This could be due to: -/// -/// - Forgetting to add the middleware: [`Error::MissingExtension`]. -/// - Calling the extractor multiple times in the same request: [`Error::OverlappingExtractors`]. -/// - A problem communicating with the database: [`Error::Database`]. -/// -/// 2. The middleware [`Layer`] might fail to commit the transaction. This could be due to a problem -/// communicating with the database, or else a logic error (e.g. unsatisfied deferred -/// constraint): [`Error::Database`]. -/// -/// `axum` requires that errors can be turned into responses. The [`Error`] type converts into a -/// HTTP 500 response with the error message as the response body. This may be suitable for -/// development or internal services but it's generally not advisable to return internal error -/// details to clients. -/// -/// You can override the error types for both the [`Tx`] extractor and [`Layer`]: -/// -/// - Override the [`Tx`]`` error type using the `E` generic type parameter. `E` must be -/// convertible from [`Error`] (e.g. [`Error`]`: Into`). -/// -/// - Override the [`Layer`] error type using [`Config::layer_error`]. The layer error type must be -/// convertible from `sqlx::Error` (e.g. `sqlx::Error: Into`). -/// -/// In both cases, the error type must implement `axum::response::IntoResponse`. -/// -/// ``` -/// use axum::{response::IntoResponse, routing::post}; -/// -/// enum MyError{ -/// Extractor(axum_sqlx_tx::Error), -/// Layer(sqlx::Error), -/// } -/// -/// impl From for MyError { -/// fn from(error: axum_sqlx_tx::Error) -> Self { -/// Self::Extractor(error) -/// } -/// } -/// -/// impl From for MyError { -/// fn from(error: sqlx::Error) -> Self { -/// Self::Layer(error) -/// } -/// } -/// -/// impl IntoResponse for MyError { -/// fn into_response(self) -> axum::response::Response { -/// // note that you would probably want to log the error as well -/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response() -/// } -/// } -/// -/// // Override the `Tx` error type using the second generic type parameter -/// type Tx = axum_sqlx_tx::Tx; -/// -/// # async fn foo() { -/// let pool = sqlx::SqlitePool::connect("...").await.unwrap(); -/// -/// let (state, layer) = Tx::config(pool) -/// // Override the `Layer` error type using the `Config` API -/// .layer_error::() -/// .setup(); -/// # let app = axum::Router::new() -/// # .route("/", post(create_user)) -/// # .layer(layer) -/// # .with_state(state); -/// # axum::Server::bind(todo!()).serve(app.into_make_service()); -/// # } -/// # async fn create_user(mut tx: Tx, /* ... */) { -/// # /* ... */ -/// # } -/// ``` -#[derive(Debug, thiserror::Error)] -pub enum Error { - /// Indicates that the [`Layer`] middleware was not installed. - #[error("required extension not registered; did you add the axum_sqlx_tx::Layer middleware?")] - MissingExtension, - - /// Indicates that [`Tx`] was extracted multiple times in a single handler/middleware. - #[error("axum_sqlx_tx::Tx extractor used multiple times in the same handler/middleware")] - OverlappingExtractors, - - /// A database error occurred when starting or committing the transaction. - #[error(transparent)] - Database { - #[from] - error: sqlx::Error, - }, -} - -impl axum_core::response::IntoResponse for Error { - fn into_response(self) -> axum_core::response::Response { - (http::StatusCode::INTERNAL_SERVER_ERROR, self.to_string()).into_response() - } -} diff --git a/src/marker.rs b/src/marker.rs new file mode 100644 index 0000000..11eccbd --- /dev/null +++ b/src/marker.rs @@ -0,0 +1,76 @@ +use std::fmt::Debug; + +/// Extractor marker type. +/// +/// Since the [`Tx`](crate::Tx) extractor operates at the type level, a generic type parameter is +/// used to identify different databases. +/// +/// There is a blanket implementation for all implementors of [`sqlx::Database`], but you can create +/// your own types if you need to work with multiple databases. +/// +/// ``` +/// // Marker struct "database 1" +/// #[derive(Debug)] +/// struct Db1; +/// +/// impl axum_sqlx_tx::Marker for Db1 { +/// type Driver = sqlx::Sqlite; +/// } +/// +/// // Marker struct "database 2" +/// #[derive(Debug)] +/// struct Db2; +/// +/// impl axum_sqlx_tx::Marker for Db2 { +/// type Driver = sqlx::Sqlite; +/// } +/// +/// // You'll also need a "state" structure that implements `FromRef` for each `State` +/// #[derive(Clone)] +/// struct MyState { +/// state1: axum_sqlx_tx::State, +/// state2: axum_sqlx_tx::State, +/// } +/// +/// impl axum::extract::FromRef for axum_sqlx_tx::State { +/// fn from_ref(state: &MyState) -> Self { +/// state.state1.clone() +/// } +/// } +/// +/// impl axum::extract::FromRef for axum_sqlx_tx::State { +/// fn from_ref(state: &MyState) -> Self { +/// state.state2.clone() +/// } +/// } +/// +/// // The extractor can then be aliased for each DB +/// type Tx1 = axum_sqlx_tx::Tx; +/// type Tx2 = axum_sqlx_tx::Tx; +/// +/// # async fn foo() { +/// // Setup each extractor +/// let pool1 = sqlx::SqlitePool::connect("...").await.unwrap(); +/// let (state1, layer1) = Tx1::setup(pool1); +/// +/// let pool2 = sqlx::SqlitePool::connect("...").await.unwrap(); +/// let (state2, layer2) = Tx2::setup(pool2); +/// +/// let app = axum::Router::new() +/// .route("/", axum::routing::get(|tx1: Tx1, tx2: Tx2| async move { +/// /* ... */ +/// })) +/// .layer(layer1) +/// .layer(layer2) +/// .with_state(MyState { state1, state2 }); +/// # axum::Server::bind(todo!()).serve(app.into_make_service()); +/// # } +/// ``` +pub trait Marker: Debug + Send + Sized + 'static { + /// The `sqlx` database driver. + type Driver: sqlx::Database; +} + +impl Marker for DB { + type Driver = Self; +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..1e4d019 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,36 @@ +use crate::Marker; + +/// Application state that enables the [`Tx`] extractor. +/// +/// `State` must be provided to `Router`s in order to use the [`Tx`] extractor, or else attempting +/// to use the `Router` will not compile. +/// +/// `State` is constructed via [`Tx::setup`](crate::Tx::setup) or +/// [`Config::setup`](crate::Config::setup), which also return a middleware [`Layer`](crate::Layer). +/// The state and the middleware together enable the [`Tx`] extractor to work. +/// +/// [`Tx`]: crate::Tx +#[derive(Debug)] +pub struct State { + pool: sqlx::Pool, +} + +impl State { + pub(crate) fn new(pool: sqlx::Pool) -> Self { + Self { pool } + } + + pub(crate) async fn transaction( + &self, + ) -> Result, sqlx::Error> { + self.pool.begin().await + } +} + +impl Clone for State { + fn clone(&self) -> Self { + Self { + pool: self.pool.clone(), + } + } +} diff --git a/src/tx.rs b/src/tx.rs index 003dfd2..a013a90 100644 --- a/src/tx.rs +++ b/src/tx.rs @@ -12,7 +12,7 @@ use sqlx::Transaction; use crate::{ slot::{Lease, Slot}, - Config, Error, State, + Config, Error, Marker, State, }; /// An `axum` extractor for a database transaction. @@ -74,12 +74,12 @@ use crate::{ /// } /// ``` #[derive(Debug)] -pub struct Tx { - tx: Lease>, +pub struct Tx { + tx: Lease>, _error: PhantomData, } -impl Tx { +impl Tx { /// Crate a [`State`] and [`Layer`](crate::Layer) to enable the extractor. /// /// This is convenient to use from a type alias, e.g. @@ -92,7 +92,7 @@ impl Tx { /// let (state, layer) = Tx::setup(pool); /// # } /// ``` - pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { + pub fn setup(pool: sqlx::Pool) -> (State, crate::Layer) { Config::new(pool).setup() } @@ -110,7 +110,7 @@ impl Tx { /// let config = Tx::config(pool); /// # } /// ``` - pub fn config(pool: sqlx::Pool) -> Config { + pub fn config(pool: sqlx::Pool) -> Config { Config::new(pool) } @@ -127,33 +127,33 @@ impl Tx { } } -impl AsRef> for Tx { - fn as_ref(&self) -> &sqlx::Transaction<'static, DB> { +impl AsRef> for Tx { + fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> { &self.tx } } -impl AsMut> for Tx { - fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB> { +impl AsMut> for Tx { + fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> { &mut self.tx } } -impl std::ops::Deref for Tx { - type Target = sqlx::Transaction<'static, DB>; +impl std::ops::Deref for Tx { + type Target = sqlx::Transaction<'static, DB::Driver>; fn deref(&self) -> &Self::Target { &self.tx } } -impl std::ops::DerefMut for Tx { +impl std::ops::DerefMut for Tx { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.tx } } -impl FromRequestParts for Tx +impl FromRequestParts for Tx where E: From + IntoResponse, State: FromRef, @@ -183,9 +183,9 @@ where } /// The OG `Slot` – the transaction (if any) returns here when the `Extension` is dropped. -pub(crate) struct TxSlot(Slot>>>); +pub(crate) struct TxSlot(Slot>>>); -impl TxSlot { +impl TxSlot { /// Create a `TxSlot` bound to the given request extensions. /// /// When the request extensions are dropped, `commit` can be called to commit the transaction @@ -208,13 +208,13 @@ impl TxSlot { /// /// When the transaction is started, it's inserted into the `Option` leased from the `TxSlot`, so /// that when `Lazy` is dropped the transaction is moved to the `TxSlot`. -struct Lazy { +struct Lazy { state: State, - tx: Lease>>>, + tx: Lease>>>, } -impl Lazy { - async fn get_or_begin(&mut self) -> Result>, Error> { +impl Lazy { + async fn get_or_begin(&mut self) -> Result>, Error> { let tx = if let Some(tx) = self.tx.as_mut() { tx } else { @@ -228,11 +228,12 @@ impl Lazy { impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx where - DB: sqlx::Database, - for<'t> &'t mut DB::Connection: sqlx::Executor<'t, Database = DB>, + DB: Marker, + for<'t> &'t mut ::Connection: + sqlx::Executor<'t, Database = DB::Driver>, E: std::fmt::Debug + Send, { - type Database = DB; + type Database = DB::Driver; #[allow(clippy::type_complexity)] fn fetch_many<'e, 'q: 'e, Q: 'q>( diff --git a/tests/lib.rs b/tests/lib.rs index b03a041..7351dfb 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -260,6 +260,93 @@ async fn layer_error_override() { assert_eq!(body, "internal server error"); } +#[tokio::test] +async fn multi_db() { + #[derive(Debug)] + struct DbA; + impl axum_sqlx_tx::Marker for DbA { + type Driver = sqlx::Sqlite; + } + type TxA = axum_sqlx_tx::Tx; + + #[derive(Debug)] + struct DbB; + impl axum_sqlx_tx::Marker for DbB { + type Driver = sqlx::Sqlite; + } + type TxB = axum_sqlx_tx::Tx; + + let pool_a = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); + let pool_b = sqlx::SqlitePool::connect("sqlite::memory:").await.unwrap(); + + sqlx::query("CREATE TABLE IF NOT EXISTS users (id INT PRIMARY KEY);") + .execute(&pool_a) + .await + .unwrap(); + sqlx::query( + r#" + CREATE TABLE IF NOT EXISTS comments ( + id INT PRIMARY KEY, + user_id INT + );"#, + ) + .execute(&pool_b) + .await + .unwrap(); + + let (state_a, layer_a) = TxA::setup(pool_a); + let (state_b, layer_b) = TxB::setup(pool_b); + + #[derive(Clone)] + struct State { + state_a: axum_sqlx_tx::State, + state_b: axum_sqlx_tx::State, + } + + impl axum::extract::FromRef for axum_sqlx_tx::State { + fn from_ref(input: &State) -> Self { + input.state_a.clone() + } + } + + impl axum::extract::FromRef for axum_sqlx_tx::State { + fn from_ref(input: &State) -> Self { + input.state_b.clone() + } + } + + let app = axum::Router::new() + .route( + "/", + axum::routing::get(|mut tx_a: TxA, mut tx_b: TxB| async move { + sqlx::query("SELECT * FROM users") + .execute(&mut tx_a) + .await + .unwrap(); + sqlx::query("SELECT * FROM comments") + .execute(&mut tx_b) + .await + .unwrap(); + }), + ) + .layer(layer_a) + .layer(layer_b) + .with_state(State { state_a, state_b }); + + let response = app + .oneshot( + http::Request::builder() + .uri("/") + .body(axum::body::Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + let status = response.status(); + + assert!(status.is_success()); +} + async fn insert_user(tx: &mut Tx, id: i32, name: &str) -> (i32, String) { let mut args = SqliteArguments::default(); args.add(id);