diff --git a/benchmarks_test.go b/benchmarks_test.go index 5b7929b854..21c3d72574 100644 --- a/benchmarks_test.go +++ b/benchmarks_test.go @@ -121,6 +121,13 @@ func Benchmark404Many(B *testing.B) { runRequest(B, router, "GET", "/viewfake") } +func BenchmarkOnRedirect(B *testing.B) { + router := New() + router.GET("/something", func(c *Context) {}) + router.OnRedirect(func(c *Context) {}) + runRequest(B, router, "GET", "/something/") +} + type mockWriter struct { headers http.Header } diff --git a/gin.go b/gin.go index 24a9864aff..7cbba093d8 100644 --- a/gin.go +++ b/gin.go @@ -156,20 +156,22 @@ type Engine struct { // ContextWithFallback enable fallback Context.Deadline(), Context.Done(), Context.Err() and Context.Value() when Context.Request.Context() is not nil. ContextWithFallback bool - delims render.Delims - secureJSONPrefix string - HTMLRender render.HTMLRender - FuncMap template.FuncMap - allNoRoute HandlersChain - allNoMethod HandlersChain - noRoute HandlersChain - noMethod HandlersChain - pool sync.Pool - trees methodTrees - maxParams uint16 - maxSections uint16 - trustedProxies []string - trustedCIDRs []*net.IPNet + delims render.Delims + secureJSONPrefix string + HTMLRender render.HTMLRender + FuncMap template.FuncMap + allNoRoute HandlersChain + allNoMethod HandlersChain + allRedirectMethod HandlersChain + noRoute HandlersChain + noMethod HandlersChain + redirectMethod HandlersChain + pool sync.Pool + trees methodTrees + maxParams uint16 + maxSections uint16 + trustedProxies []string + trustedCIDRs []*net.IPNet } var _ IRouter = (*Engine)(nil) @@ -297,6 +299,12 @@ func (engine *Engine) NoRoute(handlers ...HandlerFunc) { engine.rebuild404Handlers() } +// OnRedirect adds handlers for when redirects are made. +func (engine *Engine) OnRedirect(handlers ...HandlerFunc) { + engine.redirectMethod = handlers + engine.rebuildRedirectHandlers() +} + // NoMethod sets the handlers called when Engine.HandleMethodNotAllowed = true. func (engine *Engine) NoMethod(handlers ...HandlerFunc) { engine.noMethod = handlers @@ -310,6 +318,7 @@ func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes { engine.RouterGroup.Use(middleware...) engine.rebuild404Handlers() engine.rebuild405Handlers() + engine.rebuildRedirectHandlers() return engine } @@ -321,6 +330,10 @@ func (engine *Engine) rebuild405Handlers() { engine.allNoMethod = engine.combineHandlers(engine.noMethod) } +func (engine *Engine) rebuildRedirectHandlers() { + engine.allRedirectMethod = engine.combineHandlers(engine.redirectMethod) +} + func (engine *Engine) addRoute(method, path string, handlers HandlersChain) { assert1(path[0] == '/', "path must begin with '/'") assert1(method != "", "HTTP method can not be empty") @@ -623,12 +636,17 @@ func (engine *Engine) handleHTTPRequest(c *Context) { return } if httpMethod != http.MethodConnect && rPath != "/" { + executeRedirectionMiddlewares := len(engine.redirectMethod) > 0 if value.tsr && engine.RedirectTrailingSlash { - redirectTrailingSlash(c) + c.handlers = engine.allRedirectMethod + redirectTrailingSlash(c, executeRedirectionMiddlewares) return } - if engine.RedirectFixedPath && redirectFixedPath(c, root, engine.RedirectFixedPath) { - return + if engine.RedirectFixedPath { + c.handlers = engine.allRedirectMethod + if redirectFixedPath(c, root, engine.RedirectFixedPath, executeRedirectionMiddlewares) { + return + } } } break @@ -677,7 +695,7 @@ func serveError(c *Context, code int, defaultMessage []byte) { c.writermem.WriteHeaderNow() } -func redirectTrailingSlash(c *Context) { +func redirectTrailingSlash(c *Context, withRedirectionMiddlewares bool) { req := c.Request p := req.URL.Path if prefix := path.Clean(c.Request.Header.Get("X-Forwarded-Prefix")); prefix != "." { @@ -690,22 +708,22 @@ func redirectTrailingSlash(c *Context) { if length := len(p); length > 1 && p[length-1] == '/' { req.URL.Path = p[:length-1] } - redirectRequest(c) + redirectRequest(c, withRedirectionMiddlewares) } -func redirectFixedPath(c *Context, root *node, trailingSlash bool) bool { +func redirectFixedPath(c *Context, root *node, trailingSlash, withRedirectionMiddlewares bool) bool { req := c.Request rPath := req.URL.Path if fixedPath, ok := root.findCaseInsensitivePath(cleanPath(rPath), trailingSlash); ok { req.URL.Path = bytesconv.BytesToString(fixedPath) - redirectRequest(c) + redirectRequest(c, withRedirectionMiddlewares) return true } return false } -func redirectRequest(c *Context) { +func redirectRequest(c *Context, executeRequestChain bool) { req := c.Request rPath := req.URL.Path rURL := req.URL.String() @@ -715,6 +733,9 @@ func redirectRequest(c *Context) { code = http.StatusTemporaryRedirect } debugPrint("redirecting request %d: %s --> %s", code, rPath, rURL) + if executeRequestChain { + c.Next() + } http.Redirect(c.Writer, req, rURL, code) c.writermem.WriteHeaderNow() } diff --git a/ginS/gins.go b/ginS/gins.go index ea38c613ce..1de2a7ff8c 100644 --- a/ginS/gins.go +++ b/ginS/gins.go @@ -42,6 +42,11 @@ func NoRoute(handlers ...gin.HandlerFunc) { engine().NoRoute(handlers...) } +// OnRedirect is a wrapper for Engine.OnRedirect. +func OnRedirect(handlers ...gin.HandlerFunc) { + engine().OnRedirect(handlers...) +} + // NoMethod is a wrapper for Engine.NoMethod. func NoMethod(handlers ...gin.HandlerFunc) { engine().NoMethod(handlers...) diff --git a/gin_test.go b/gin_test.go index 8825ac7ef8..3f1e6a01eb 100644 --- a/gin_test.go +++ b/gin_test.go @@ -433,6 +433,59 @@ func TestNoMethodWithoutGlobalHandlers(t *testing.T) { compareFunc(t, router.allNoMethod[1], middleware0) } +func TestOnRedirectWithoutGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + + router := New() + + router.OnRedirect(middleware0) + assert.Nil(t, router.Handlers) + assert.Len(t, router.redirectMethod, 1) + assert.Len(t, router.allRedirectMethod, 1) + compareFunc(t, router.redirectMethod[0], middleware0) + compareFunc(t, router.allRedirectMethod[0], middleware0) + + router.OnRedirect(middleware1, middleware0) + assert.Len(t, router.redirectMethod, 2) + assert.Len(t, router.allRedirectMethod, 2) + compareFunc(t, router.redirectMethod[0], middleware1) + compareFunc(t, router.allRedirectMethod[0], middleware1) + compareFunc(t, router.redirectMethod[1], middleware0) + compareFunc(t, router.allRedirectMethod[1], middleware0) +} + +func TestOnRedirectWithGlobalHandlers(t *testing.T) { + var middleware0 HandlerFunc = func(c *Context) {} + var middleware1 HandlerFunc = func(c *Context) {} + var middleware2 HandlerFunc = func(c *Context) {} + + router := New() + router.Use(middleware2) + + router.OnRedirect(middleware0) + assert.Len(t, router.allRedirectMethod, 2) + assert.Len(t, router.Handlers, 1) + assert.Len(t, router.redirectMethod, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.redirectMethod[0], middleware0) + compareFunc(t, router.allRedirectMethod[0], middleware2) + compareFunc(t, router.allRedirectMethod[1], middleware0) + + router.Use(middleware1) + assert.Len(t, router.allRedirectMethod, 3) + assert.Len(t, router.Handlers, 2) + assert.Len(t, router.redirectMethod, 1) + + compareFunc(t, router.Handlers[0], middleware2) + compareFunc(t, router.Handlers[1], middleware1) + compareFunc(t, router.redirectMethod[0], middleware0) + compareFunc(t, router.allRedirectMethod[0], middleware2) + compareFunc(t, router.allRedirectMethod[1], middleware1) + compareFunc(t, router.allRedirectMethod[2], middleware0) +} + func TestRebuild404Handlers(t *testing.T) { } diff --git a/go.mod b/go.mod index fbbce7c0fc..9c7ecaf555 100644 --- a/go.mod +++ b/go.mod @@ -32,6 +32,8 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.7.0 // indirect golang.org/x/crypto v0.19.0 // indirect + golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 // indirect golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect + golang.org/x/tools v0.18.0 // indirect ) diff --git a/go.sum b/go.sum index ce6c7fe703..bb68158d46 100644 --- a/go.sum +++ b/go.sum @@ -63,16 +63,31 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.7.0 h1:pskyeJh/3AmoQ8CPE95vxHLqp1G1GfGNXTmcl9NEKTc= golang.org/x/arch v0.7.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= +golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.18.0 h1:k8NLag8AGHnn+PHbl7g43CtqZAwG60vZkLqgyZgIHgQ= +golang.org/x/tools v0.18.0/go.mod h1:GL7B4CwcLLeo59yx/9UWWuNOW1n3VZ4f5axWfML7Lcg= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= google.golang.org/protobuf v1.32.0 h1:pPC6BG5ex8PDFnkbrGU3EixyhKcQ2aDuBS36lqK/C7I= google.golang.org/protobuf v1.32.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= diff --git a/middleware_test.go b/middleware_test.go index acdf89c42e..30a72ceee0 100644 --- a/middleware_test.go +++ b/middleware_test.go @@ -34,6 +34,9 @@ func TestMiddlewareGeneralCase(t *testing.T) { router.NoMethod(func(c *Context) { signature += " XX " }) + router.OnRedirect(func(c *Context) { + signature += " YYY " + }) // RUN w := PerformRequest(router, "GET", "/") @@ -78,6 +81,50 @@ func TestMiddlewareNoRoute(t *testing.T) { assert.Equal(t, "ACEGHFDB", signature) } +func TestMiddlewareOnRedirect(t *testing.T) { + signature := "" + router := New() + router.Use(func(c *Context) { + signature += "A" + c.Next() + signature += "B" + }) + router.Use(func(c *Context) { + signature += "C" + c.Next() + c.Next() + c.Next() + c.Next() + signature += "D" + }) + router.NoRoute(func(c *Context) { + signature += "E" + c.Next() + signature += "F" + }, func(c *Context) { + signature += "G" + c.Next() + signature += "H" + }) + router.NoMethod(func(c *Context) { + signature += " X " + }) + router.OnRedirect(func(c *Context) { + signature += "Y" + }) + + router.GET("/foo", func(c *Context) { + c.String(200, "Hello, World!") + }) + + // RUN + w := PerformRequest(router, "GET", "/foo/") + + // TEST + assert.Equal(t, http.StatusMovedPermanently, w.Code) + assert.Equal(t, "ACYDB", signature) +} + func TestMiddlewareNoMethodEnabled(t *testing.T) { signature := "" router := New()