1
- use anyhow:: { anyhow , Context , Result } ;
1
+ use anyhow:: { bail , Context , Result } ;
2
2
use log:: { trace, warn} ;
3
3
use std:: { collections:: HashMap , fs, str:: FromStr } ;
4
4
5
5
use crate :: {
6
6
file_server:: FileServer ,
7
- http:: { HttpMethod , HttpRequest , HttpResponse , HttpResponseBuilder } ,
7
+ http:: {
8
+ response_status_codes:: HttpStatusCode , HttpMethod , HttpRequest , HttpResponse ,
9
+ HttpResponseBuilder ,
10
+ } ,
8
11
} ;
9
12
10
- type RoutingCallback = fn ( & HttpRequest ) -> Result < HttpResponse > ;
11
-
12
13
#[ derive( Debug ) ]
13
14
pub struct Router {
14
- pub routes : HashMap < Route , RoutingCallback > ,
15
+ pub routes : HashMap < StoredRoute , RoutingCallback > ,
16
+ pub catcher_routes : HashMap < HttpMethod , RoutingCallback > ,
15
17
pub file_server : Option < FileServer > ,
16
18
}
17
19
@@ -25,6 +27,7 @@ impl Router {
25
27
pub fn new ( ) -> Self {
26
28
Router {
27
29
routes : HashMap :: new ( ) ,
30
+ catcher_routes : HashMap :: new ( ) ,
28
31
file_server : None ,
29
32
}
30
33
}
@@ -34,38 +37,91 @@ impl Router {
34
37
self
35
38
}
36
39
40
+ fn find_matching_route ( & self , route : & RequestRoute ) -> Result < Option < & StoredRoute > > {
41
+ let mut excluded: Vec < & StoredRoute > = vec ! [ ] ;
42
+ let request_route_parts = route. path . split ( '/' ) ;
43
+
44
+ for ( idx, part) in request_route_parts. enumerate ( ) {
45
+ for match_candidate in self . routes . keys ( ) {
46
+ if excluded. contains ( & match_candidate) {
47
+ continue ;
48
+ } ;
49
+
50
+ if let Some ( match_part) = match_candidate. parts . get ( idx) {
51
+ if !match_part. is_dynamic && !match_part. value . eq ( part) {
52
+ excluded. push ( match_candidate) ;
53
+ }
54
+ } else {
55
+ excluded. push ( match_candidate) ;
56
+ } ;
57
+ }
58
+ }
59
+
60
+ let selected_routes: Vec < & StoredRoute > = self
61
+ . routes
62
+ . keys ( )
63
+ . filter ( |route| !excluded. contains ( route) )
64
+ . collect ( ) ;
65
+
66
+ match selected_routes. len ( ) {
67
+ 0 => Ok ( None ) ,
68
+ 1 => Ok ( Some ( selected_routes. first ( ) . unwrap ( ) ) ) ,
69
+ _ => bail ! ( "multiple selected routes is not possible" ) ,
70
+ }
71
+ }
72
+
37
73
pub fn handle_request ( & self , request : & HttpRequest ) -> Result < HttpResponse > {
38
74
let route_def = format ! ( "{} {}" , request. method, request. url) ;
39
- let route = Route :: from_str ( & route_def) ?;
75
+ let route = RequestRoute :: from_str ( & route_def) ?;
40
76
trace ! ( "trying to match route: {route_def}" ) ;
41
77
42
- let response = if let Some ( route_callback) = self . routes . get ( & route) {
43
- route_callback ( request)
44
- } else {
45
- if let Some ( file_server) = & self . file_server {
46
- match file_server. handle_file_access ( & route. path ) {
47
- Ok ( file_path) => {
48
- let mime_type = mime_guess:: from_path ( & file_path) . first_or_octet_stream ( ) ;
49
- let content = fs:: read ( file_path) ?;
50
-
51
- return HttpResponseBuilder :: new ( )
52
- . set_raw_body ( content)
53
- . set_content_type ( mime_type. as_ref ( ) )
54
- . build ( ) ;
55
- }
56
- Err ( e) => warn ! ( "failed to match file: {e}" ) ,
78
+ // test against declared routes
79
+ if let Ok ( Some ( matching_route) ) = self . find_matching_route ( & route) {
80
+ let routing_data = matching_route. extract_routing_data ( & request. url ) ?;
81
+ let callback = self . routes . get ( matching_route) . context ( "expected route" ) ?;
82
+ return callback ( request, & routing_data) ;
83
+ }
84
+
85
+ // test against file server static mappings
86
+ if let Some ( file_server) = & self . file_server {
87
+ match file_server. handle_file_access ( & route. path ) {
88
+ Ok ( file_path) => {
89
+ let mime_type = mime_guess:: from_path ( & file_path) . first_or_octet_stream ( ) ;
90
+ let content = fs:: read ( file_path) ?;
91
+
92
+ return HttpResponseBuilder :: new ( )
93
+ . set_raw_body ( content)
94
+ . set_content_type ( mime_type. as_ref ( ) )
95
+ . build ( ) ;
57
96
}
97
+ Err ( e) => warn ! ( "failed to match file: {e}" ) ,
58
98
}
99
+ }
59
100
60
- let catch_all_route = Route :: from_str ( "GET /*" ) ? ;
61
- if let Some ( catch_all_callback ) = self . routes . get ( & catch_all_route ) {
62
- return catch_all_callback ( request) ;
63
- }
101
+ // test against catcher routes
102
+ if let Some ( catcher ) = self . catcher_routes . get ( & request . method ) {
103
+ return catcher ( request, & RoutingData :: default ( ) ) ;
104
+ }
64
105
65
- Err ( anyhow ! ( "failed to match route: {route_def}" ) )
66
- } ;
106
+ HttpResponseBuilder :: new ( )
107
+ . set_status ( HttpStatusCode :: NotFound )
108
+ . build ( )
109
+ }
110
+
111
+ pub fn add_catcher_route (
112
+ & mut self ,
113
+ method : HttpMethod ,
114
+ callback : RoutingCallback ,
115
+ ) -> Result < ( ) > {
116
+ if self . catcher_routes . contains_key ( & method) {
117
+ bail ! (
118
+ "cannot register catcher because one already exists for: {}" ,
119
+ method. to_string( )
120
+ ) ;
121
+ }
67
122
68
- response
123
+ self . catcher_routes . insert ( method, callback) ;
124
+ Ok ( ( ) )
69
125
}
70
126
71
127
pub fn add_route (
@@ -74,13 +130,13 @@ impl Router {
74
130
path : & str ,
75
131
callback : RoutingCallback ,
76
132
) -> Result < ( ) > {
77
- let route = Route :: new ( method, & path) ;
133
+ let route = StoredRoute :: new ( method, path) ? ;
78
134
79
135
if self . routes . contains_key ( & route) {
80
- return Err ( anyhow ! (
136
+ bail ! (
81
137
"cannot register route {:?} because a similar route already exists" ,
82
138
route
83
- ) ) ;
139
+ ) ;
84
140
}
85
141
86
142
self . routes . insert ( route, callback) ;
@@ -131,53 +187,151 @@ impl Router {
131
187
self . add_route ( HttpMethod :: PATCH , path, callback) ?;
132
188
Ok ( self )
133
189
}
190
+
191
+ pub fn catch_all ( mut self , method : HttpMethod , callback : RoutingCallback ) -> Result < Self > {
192
+ self . add_catcher_route ( method, callback) ?;
193
+ Ok ( self )
194
+ }
134
195
}
135
196
136
197
#[ derive( Debug , Hash , Eq , PartialEq , Clone ) ]
137
- pub struct Route {
198
+ pub struct StoredRoute {
138
199
pub method : HttpMethod ,
139
200
pub path : String ,
201
+ pub parts : Vec < RoutePart > ,
140
202
}
141
203
142
- impl Route {
143
- pub fn new ( method : HttpMethod , path : & str ) -> Route {
204
+ impl StoredRoute {
205
+ pub fn new ( method : HttpMethod , path : & str ) -> Result < Self > {
144
206
let path = path. trim_matches ( '/' ) . to_owned ( ) ;
145
- Route { method, path }
207
+
208
+ let mut parts = vec ! [ ] ;
209
+ for part in path. split ( '/' ) {
210
+ let is_dynamic = part. starts_with ( ':' ) ;
211
+ let value = if is_dynamic {
212
+ part[ 1 ..] . to_string ( )
213
+ } else {
214
+ part. to_string ( )
215
+ } ;
216
+
217
+ if value. contains ( ':' ) {
218
+ bail ! ( "nested `:` is not allowed in dynamic route part" ) ;
219
+ }
220
+
221
+ parts. push ( RoutePart { is_dynamic, value } ) ;
222
+ }
223
+
224
+ Ok ( Self {
225
+ method,
226
+ path,
227
+ parts,
228
+ } )
229
+ }
230
+
231
+ pub fn extract_routing_data ( & self , request_url : & str ) -> Result < RoutingData > {
232
+ let request_parts: Vec < _ > = request_url. split ( '/' ) . filter ( |p| !p. is_empty ( ) ) . collect ( ) ;
233
+
234
+ let mut params: HashMap < String , String > = HashMap :: new ( ) ;
235
+ for ( idx, part) in self . parts . iter ( ) . enumerate ( ) {
236
+ if !part. is_dynamic {
237
+ continue ;
238
+ }
239
+
240
+ let part_value = * request_parts
241
+ . get ( idx)
242
+ . context ( "part `{part_idx}` should exist" ) ?;
243
+
244
+ params. insert ( part. value . to_owned ( ) , part_value. to_owned ( ) ) ;
245
+ }
246
+
247
+ Ok ( RoutingData { params } )
146
248
}
147
249
}
148
250
149
- impl FromStr for Route {
251
+ #[ derive( Debug , Hash , Eq , PartialEq , Clone ) ]
252
+ pub struct RoutePart {
253
+ pub is_dynamic : bool ,
254
+ pub value : String ,
255
+ }
256
+
257
+ #[ derive( Debug , Hash , Eq , PartialEq , Clone ) ]
258
+ pub struct RequestRoute {
259
+ pub method : HttpMethod ,
260
+ pub path : String ,
261
+ }
262
+
263
+ impl RequestRoute {
264
+ pub fn new ( method : HttpMethod , path : & str ) -> RequestRoute {
265
+ let path = path. trim_matches ( '/' ) . to_owned ( ) ;
266
+ RequestRoute { method, path }
267
+ }
268
+ }
269
+
270
+ impl FromStr for RequestRoute {
150
271
type Err = anyhow:: Error ;
151
272
152
273
fn from_str ( s : & str ) -> std:: result:: Result < Self , Self :: Err > {
153
- let ( method, path) = s. split_once ( " " ) . context ( "route should have: VERB PATH" ) ?;
274
+ let ( method, path) = s
275
+ . split_once ( " " )
276
+ . context ( "route should have following format: METHOD PATH (ex: GET /index)" ) ?;
154
277
let method = HttpMethod :: from_str ( method) ?;
155
278
156
- Ok ( Route :: new ( method, path) )
279
+ Ok ( RequestRoute :: new ( method, path) )
157
280
}
158
281
}
159
282
283
+ type RoutingCallback = fn ( & HttpRequest , & RoutingData ) -> Result < HttpResponse > ;
284
+
285
+ #[ derive( Debug , Default ) ]
286
+ pub struct RoutingData {
287
+ pub params : HashMap < String , String > ,
288
+ }
289
+
160
290
#[ cfg( test) ]
161
291
mod tests {
162
- use serde_json:: json;
292
+ use serde_json:: { json, Value } ;
163
293
164
294
use crate :: http:: { HttpRequestRaw , HttpResponseBuilder } ;
165
295
166
296
use super :: * ;
167
297
168
- fn get_hello_callback ( _request : & HttpRequest ) -> Result < HttpResponse > {
298
+ fn get_hello_callback (
299
+ _request : & HttpRequest ,
300
+ _routing_data : & RoutingData ,
301
+ ) -> Result < HttpResponse > {
169
302
HttpResponseBuilder :: new ( )
170
303
. set_html_body ( "Hello World!" )
171
304
. build ( )
172
305
}
173
306
174
- fn post_user_callback ( _request : & HttpRequest ) -> Result < HttpResponse > {
307
+ fn get_user_by_id ( _request : & HttpRequest , routing_data : & RoutingData ) -> Result < HttpResponse > {
308
+ let id = routing_data. params . get ( "id" ) . unwrap ( ) ;
309
+ let username = format ! ( "user_{id}" ) ;
310
+ let json = json ! ( { "username" : username } ) ;
311
+
312
+ HttpResponseBuilder :: new ( ) . set_json_body ( & json) ?. build ( )
313
+ }
314
+
315
+ fn get_user_info ( _request : & HttpRequest , routing_data : & RoutingData ) -> Result < HttpResponse > {
316
+ let id = routing_data. params . get ( "id" ) . unwrap ( ) ;
317
+ let info_field = routing_data. params . get ( "field" ) . unwrap ( ) ;
318
+
319
+ let username = format ! ( "user_{id}" ) ;
320
+ let json = json ! ( { "username" : username, "field" : info_field } ) ;
321
+
322
+ HttpResponseBuilder :: new ( ) . set_json_body ( & json) ?. build ( )
323
+ }
324
+
325
+ fn post_user_callback (
326
+ _request : & HttpRequest ,
327
+ _routing_data : & RoutingData ,
328
+ ) -> Result < HttpResponse > {
175
329
let json = json ! ( { "created" : true } ) ;
176
330
HttpResponseBuilder :: new ( ) . set_json_body ( & json) ?. build ( )
177
331
}
178
332
179
333
#[ test]
180
- fn test_unknown_route_err ( ) {
334
+ fn test_unmatched_no_catcher ( ) {
181
335
let router = Router :: new ( ) ;
182
336
183
337
let request = HttpRequest :: from_raw_request ( HttpRequestRaw {
@@ -187,13 +341,15 @@ mod tests {
187
341
} )
188
342
. unwrap ( ) ;
189
343
190
- let response = router. handle_request ( & request) ;
191
- assert ! ( response . is_err ( ) ) ;
344
+ let response = router. handle_request ( & request) . unwrap ( ) ;
345
+ assert_eq ! ( HttpStatusCode :: NotFound . to_string ( ) , response . status ) ;
192
346
}
193
347
194
348
#[ test]
195
- fn test_unknown_has_fallback ( ) {
196
- let router = Router :: new ( ) . get ( "/*" , get_hello_callback) . unwrap ( ) ;
349
+ fn test_unmatched_get_catcher ( ) {
350
+ let router = Router :: new ( )
351
+ . catch_all ( HttpMethod :: GET , get_hello_callback)
352
+ . unwrap ( ) ;
197
353
198
354
let request = HttpRequest :: from_raw_request ( HttpRequestRaw {
199
355
request_line : "GET /not-a-real-page HTTP/1.1" . to_owned ( ) ,
@@ -235,4 +391,41 @@ mod tests {
235
391
let response = router. handle_request ( & request) . unwrap ( ) ;
236
392
assert_eq ! ( "{\" created\" :true}\r \n " . as_bytes( ) , response. body) ;
237
393
}
394
+
395
+ #[ test]
396
+ fn test_dynamic_route ( ) {
397
+ let router = Router :: new ( )
398
+ . get ( "/users/:id/details" , get_user_by_id)
399
+ . unwrap ( ) ;
400
+
401
+ let request = HttpRequest :: from_raw_request ( HttpRequestRaw {
402
+ request_line : "GET /users/5/details HTTP/1.1" . to_owned ( ) ,
403
+ headers : Vec :: new ( ) ,
404
+ body : vec ! [ ] ,
405
+ } )
406
+ . unwrap ( ) ;
407
+
408
+ let response = router. handle_request ( & request) . unwrap ( ) ;
409
+ let actual_res: Value = serde_json:: from_slice ( & response. body ) . unwrap ( ) ;
410
+ assert_eq ! ( "user_5" , actual_res[ "username" ] ) ;
411
+ }
412
+
413
+ #[ test]
414
+ fn test_dynamic_route_multiparams ( ) {
415
+ let router = Router :: new ( )
416
+ . get ( "/users/:id/info/:field" , get_user_info)
417
+ . unwrap ( ) ;
418
+
419
+ let request = HttpRequest :: from_raw_request ( HttpRequestRaw {
420
+ request_line : "GET /users/17/info/gender HTTP/1.1" . to_owned ( ) ,
421
+ headers : Vec :: new ( ) ,
422
+ body : vec ! [ ] ,
423
+ } )
424
+ . unwrap ( ) ;
425
+
426
+ let response = router. handle_request ( & request) . unwrap ( ) ;
427
+ let actual_res: Value = serde_json:: from_slice ( & response. body ) . unwrap ( ) ;
428
+ let expected_result = json ! ( { "username" : "user_17" , "field" : "gender" } ) ;
429
+ assert_eq ! ( expected_result, actual_res) ;
430
+ }
238
431
}
0 commit comments