@@ -7,7 +7,7 @@ use bytes::Bytes;
77use futures_core:: future:: BoxFuture ;
88use http_body:: { combinators:: UnsyncBoxBody , Body } ;
99
10- use crate :: { tx:: TxSlot , Error } ;
10+ use crate :: { tx:: TxSlot , State } ;
1111
1212/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
1313///
@@ -20,34 +20,19 @@ use crate::{tx::TxSlot, Error};
2020///
2121/// [`Tx`]: crate::Tx
2222/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
23- pub struct Layer < DB : sqlx:: Database , E = Error > {
24- pool : sqlx :: Pool < DB > ,
23+ pub struct Layer < DB : sqlx:: Database , E > {
24+ state : State < DB > ,
2525 _error : PhantomData < E > ,
2626}
2727
28- impl < DB : sqlx:: Database > Layer < DB > {
29- /// Construct a new layer with the given `pool`.
30- ///
31- /// A connection will be obtained from the pool the first time a [`Tx`](crate::Tx) is extracted
32- /// from a request.
33- ///
34- /// If you want to access the pool outside of a transaction, you should add it also with
35- /// [`axum::Extension`].
36- ///
37- /// To use a different type than [`Error`] to convert commit errors into responses, see
38- /// [`new_with_error`](Self::new_with_error).
39- ///
40- /// [`axum::Extension`]: https://docs.rs/axum/latest/axum/extract/struct.Extension.html
41- pub fn new ( pool : sqlx:: Pool < DB > ) -> Self {
42- Self :: new_with_error ( pool)
43- }
44-
45- /// Construct a new layer with a specific error type.
46- ///
47- /// See [`Layer::new`] for more information.
48- pub fn new_with_error < E > ( pool : sqlx:: Pool < DB > ) -> Layer < DB , E > {
49- Layer {
50- pool,
28+ impl < DB : sqlx:: Database , E > Layer < DB , E >
29+ where
30+ E : IntoResponse ,
31+ sqlx:: Error : Into < E > ,
32+ {
33+ pub ( crate ) fn new ( state : State < DB > ) -> Self {
34+ Self {
35+ state,
5136 _error : PhantomData ,
5237 }
5338 }
@@ -56,18 +41,22 @@ impl<DB: sqlx::Database> Layer<DB> {
5641impl < DB : sqlx:: Database , E > Clone for Layer < DB , E > {
5742 fn clone ( & self ) -> Self {
5843 Self {
59- pool : self . pool . clone ( ) ,
44+ state : self . state . clone ( ) ,
6045 _error : self . _error ,
6146 }
6247 }
6348}
6449
65- impl < DB : sqlx:: Database , S , E > tower_layer:: Layer < S > for Layer < DB , E > {
50+ impl < DB : sqlx:: Database , S , E > tower_layer:: Layer < S > for Layer < DB , E >
51+ where
52+ E : IntoResponse ,
53+ sqlx:: Error : Into < E > ,
54+ {
6655 type Service = Service < DB , S , E > ;
6756
6857 fn layer ( & self , inner : S ) -> Self :: Service {
6958 Service {
70- pool : self . pool . clone ( ) ,
59+ state : self . state . clone ( ) ,
7160 inner,
7261 _error : self . _error ,
7362 }
@@ -77,8 +66,8 @@ impl<DB: sqlx::Database, S, E> tower_layer::Layer<S> for Layer<DB, E> {
7766/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
7867///
7968/// See [`Layer`] for more information.
80- pub struct Service < DB : sqlx:: Database , S , E = Error > {
81- pool : sqlx :: Pool < DB > ,
69+ pub struct Service < DB : sqlx:: Database , S , E > {
70+ state : State < DB > ,
8271 inner : S ,
8372 _error : PhantomData < E > ,
8473}
@@ -87,7 +76,7 @@ pub struct Service<DB: sqlx::Database, S, E = Error> {
8776impl < DB : sqlx:: Database , S : Clone , E > Clone for Service < DB , S , E > {
8877 fn clone ( & self ) -> Self {
8978 Self {
90- pool : self . pool . clone ( ) ,
79+ state : self . state . clone ( ) ,
9180 inner : self . inner . clone ( ) ,
9281 _error : self . _error ,
9382 }
10392 Error = std:: convert:: Infallible ,
10493 > ,
10594 S :: Future : Send + ' static ,
106- E : From < Error > + IntoResponse ,
95+ E : IntoResponse ,
96+ sqlx:: Error : Into < E > ,
10797 ResBody : Body < Data = Bytes > + Send + ' static ,
10898 ResBody :: Error : Into < Box < dyn std:: error:: Error + Send + Sync + ' static > > ,
10999{
@@ -119,7 +109,7 @@ where
119109 }
120110
121111 fn call ( & mut self , mut req : http:: Request < ReqBody > ) -> Self :: Future {
122- let transaction = TxSlot :: bind ( req. extensions_mut ( ) , self . pool . clone ( ) ) ;
112+ let transaction = TxSlot :: bind ( req. extensions_mut ( ) , self . state . clone ( ) ) ;
123113
124114 let res = self . inner . call ( req) ;
125115
@@ -128,7 +118,7 @@ where
128118
129119 if !res. status ( ) . is_server_error ( ) && !res. status ( ) . is_client_error ( ) {
130120 if let Err ( error) = transaction. commit ( ) . await {
131- return Ok ( E :: from ( Error :: Database { error } ) . into_response ( ) ) ;
121+ return Ok ( error. into ( ) . into_response ( ) ) ;
132122 }
133123 }
134124
@@ -139,17 +129,21 @@ where
139129
140130#[ cfg( test) ]
141131mod tests {
132+ use crate :: { Error , State } ;
133+
142134 use super :: Layer ;
143135
144136 // The trait shenanigans required by axum for layers are significant, so this "test" ensures
145137 // we've got it right.
146138 #[ allow( unused, unreachable_code, clippy:: diverging_sub_expression) ]
147139 fn layer_compiles ( ) {
148- let pool: sqlx:: Pool < sqlx:: Sqlite > = todo ! ( ) ;
140+ let state: State < sqlx:: Sqlite > = todo ! ( ) ;
141+
142+ let layer = Layer :: < _ , Error > :: new ( state) ;
149143
150144 let app = axum:: Router :: new ( )
151145 . route ( "/" , axum:: routing:: get ( || async { "hello" } ) )
152- . layer ( Layer :: new ( pool ) ) ;
146+ . layer ( layer ) ;
153147
154148 axum:: Server :: bind ( todo ! ( ) ) . serve ( app. into_make_service ( ) ) ;
155149 }
0 commit comments