diff --git a/middleware/redirect.go b/middleware/redirect.go index 6017ddbb8..a6175a79d 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -6,29 +6,28 @@ import ( "github.com/labstack/echo" ) -type ( - // RedirectConfig defines the config for Redirect middleware. - RedirectConfig struct { - // Skipper defines a function to skip middleware. - Skipper Skipper - - // Status code to be used when redirecting the request. - // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` - } -) +// RedirectConfig defines the config for Redirect middleware. +type RedirectConfig struct { + // Skipper defines a function to skip middleware. + Skipper + + // Status code to be used when redirecting the request. + // Optional. Default value http.StatusMovedPermanently. + Code int `yaml:"code"` +} -const ( - www = "www" -) +// redirectLogic represents a function that given a tls flag, host and uri +// can both: 1) determine if redirect is needed (will set ok accordingly) and +// 2) return the appropriate redirect url. +type redirectLogic func(tls bool, scheme, host, uri string) (ok bool, url string) -var ( - // DefaultRedirectConfig is the default Redirect middleware config. - DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, - } -) +const www = "www" + +// DefaultRedirectConfig is the default Redirect middleware config. +var DefaultRedirectConfig = RedirectConfig{ + Skipper: DefaultSkipper, + Code: http.StatusMovedPermanently, +} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. @@ -41,29 +40,12 @@ func HTTPSRedirect() echo.MiddlewareFunc { // HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSRedirect()`. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper - } - if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - host := req.Host - uri := req.RequestURI - if !c.IsTLS() { - return c.Redirect(config.Code, "https://"+host+uri) - } - return next(c) + return redirect(config, func(isTLS bool, _, host, uri string) (ok bool, url string) { + if ok = !isTLS; ok { + url = "https://" + host + uri } - } + return + }) } // HTTPSWWWRedirect redirects http requests to https www. @@ -77,29 +59,12 @@ func HTTPSWWWRedirect() echo.MiddlewareFunc { // HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSWWWRedirect()`. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper - } - if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - host := req.Host - uri := req.RequestURI - if !c.IsTLS() && host[:3] != www { - return c.Redirect(config.Code, "https://www."+host+uri) - } - return next(c) + return redirect(config, func(isTLS bool, _, host, uri string) (ok bool, url string) { + if ok = !isTLS && host[:3] != www; ok { + url = "https://www." + host + uri } - } + return + }) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -113,32 +78,15 @@ func HTTPSNonWWWRedirect() echo.MiddlewareFunc { // HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `HTTPSNonWWWRedirect()`. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper - } - if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - host := req.Host - uri := req.RequestURI - if !c.IsTLS() { - if host[:3] == www { - return c.Redirect(config.Code, "https://"+host[4:]+uri) - } - return c.Redirect(config.Code, "https://"+host+uri) + return redirect(config, func(isTLS bool, _, host, uri string) (ok bool, url string) { + if ok = !isTLS; ok { + if host[:3] == www { + host = host[4:] } - return next(c) + url = "https://" + host + uri } - } + return + }) } // WWWRedirect redirects non www requests to www. @@ -152,30 +100,12 @@ func WWWRedirect() echo.MiddlewareFunc { // WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `WWWRedirect()`. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultTrailingSlashConfig.Skipper - } - if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - scheme := c.Scheme() - host := req.Host - if host[:3] != www { - uri := req.RequestURI - return c.Redirect(config.Code, scheme+"://www."+host+uri) - } - return next(c) + return redirect(config, func(_ bool, scheme, host, uri string) (ok bool, url string) { + if ok = host[:3] != www; ok { + url = scheme + "://www." + host + uri } - } + return + }) } // NonWWWRedirect redirects www requests to non www. @@ -189,6 +119,15 @@ func NonWWWRedirect() echo.MiddlewareFunc { // NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. // See `NonWWWRedirect()`. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { + return redirect(config, func(isTLS bool, scheme, host, uri string) (ok bool, url string) { + if ok = host[:3] == www; ok { + url = scheme + "://" + host[4:] + uri + } + return + }) +} + +func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultTrailingSlashConfig.Skipper } @@ -202,13 +141,12 @@ func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { return next(c) } - req := c.Request() - scheme := c.Scheme() + req, scheme := c.Request(), c.Scheme() host := req.Host - if host[:3] == www { - uri := req.RequestURI - return c.Redirect(config.Code, scheme+"://"+host[4:]+uri) + if ok, url := cb(c.IsTLS(), scheme, host, req.RequestURI); ok { + return c.Redirect(config.Code, url) } + return next(c) } }