Skip to content

Commit 489afe6

Browse files
authored
feat(csrf): Support async IsAllowedOriginHandler (#4558)
* support async IsAllowedOriginHandler * remove return type
1 parent f8c5a79 commit 489afe6

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

src/middleware/csrf/index.test.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,51 @@ describe('CSRF by Middleware', () => {
324324
expect(simplePostHandler).not.toHaveBeenCalled()
325325
})
326326
})
327+
328+
describe('async IsAllowedOriginHandler', () => {
329+
const app = new Hono()
330+
331+
app.use(
332+
'*',
333+
csrf({
334+
origin: async (origin) => {
335+
await new Promise((r) => setTimeout(r, 10))
336+
return /https:\/\/(\w+\.)?example\.com$/.test(origin)
337+
},
338+
})
339+
)
340+
app.post('/form', simplePostHandler)
341+
342+
it('should be 200 for allowed origin with async handler', async () => {
343+
let res = await app.request(
344+
'https://hono.example.com/form',
345+
buildSimplePostRequestData({ origin: 'https://hono.example.com' })
346+
)
347+
expect(res.status).toBe(200)
348+
349+
res = await app.request(
350+
'https://example.com/form',
351+
buildSimplePostRequestData({ origin: 'https://example.com' })
352+
)
353+
expect(res.status).toBe(200)
354+
})
355+
356+
it('should be 403 for not allowed origin with async handler', async () => {
357+
let res = await app.request(
358+
'http://honojs.hono.example.jp/form',
359+
buildSimplePostRequestData({ origin: 'http://example.jp' })
360+
)
361+
expect(res.status).toBe(403)
362+
expect(simplePostHandler).not.toHaveBeenCalled()
363+
364+
res = await app.request(
365+
'http://example.jp/form',
366+
buildSimplePostRequestData({ origin: 'http://example.jp' })
367+
)
368+
expect(res.status).toBe(403)
369+
expect(simplePostHandler).not.toHaveBeenCalled()
370+
})
371+
})
327372
})
328373

329374
describe('with secFetchSite option', () => {

src/middleware/csrf/index.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import type { Context } from '../../context'
77
import { HTTPException } from '../../http-exception'
88
import type { MiddlewareHandler } from '../../types'
99

10-
type IsAllowedOriginHandler = (origin: string, context: Context) => boolean
10+
type IsAllowedOriginHandler = (origin: string, context: Context) => boolean | Promise<boolean>
1111

1212
const secFetchSiteValues = ['same-origin', 'same-site', 'none', 'cross-site'] as const
1313
type SecFetchSite = (typeof secFetchSiteValues)[number]
@@ -100,12 +100,12 @@ export const csrf = (options?: CSRFOptions): MiddlewareHandler => {
100100
return (origin) => optsOrigin.includes(origin)
101101
}
102102
})(options?.origin)
103-
const isAllowedOrigin = (origin: string | undefined, c: Context) => {
103+
const isAllowedOrigin = async (origin: string | undefined, c: Context) => {
104104
if (origin === undefined) {
105105
// denied always when origin header is not present
106106
return false
107107
}
108-
return originHandler(origin, c)
108+
return await originHandler(origin, c)
109109
}
110110

111111
const secFetchSiteHandler: IsAllowedSecFetchSiteHandler = ((optsSecFetchSite) => {
@@ -137,7 +137,7 @@ export const csrf = (options?: CSRFOptions): MiddlewareHandler => {
137137
!isSafeMethodRe.test(c.req.method) &&
138138
isRequestedByFormElementRe.test(c.req.header('content-type') || 'text/plain') &&
139139
!isAllowedSecFetchSite(c.req.header('sec-fetch-site'), c) &&
140-
!isAllowedOrigin(c.req.header('origin'), c)
140+
!(await isAllowedOrigin(c.req.header('origin'), c))
141141
) {
142142
const res = new Response('Forbidden', { status: 403 })
143143
throw new HTTPException(403, { res })

0 commit comments

Comments
 (0)