@@ -14,7 +14,7 @@ import (
1414 "github.com/klauspost/compress/zstd"
1515)
1616
17- // Transport will wrap a transport with a custom handler
17+ // Transport will wrap an HTTP transport with a custom handler
1818// that will request gzip and automatically decompress it.
1919// Using this is significantly faster than using the default transport.
2020func Transport (parent http.RoundTripper , opts ... transportOption ) http.RoundTripper {
@@ -51,10 +51,21 @@ func TransportEnableGzip(b bool) transportOption {
5151 }
5252}
5353
54+ // TransportCustomEval will send the header of a response to a custom function.
55+ // If the function returns false, the response will be returned as-is,
56+ // Otherwise it will be decompressed based on Content-Encoding field, regardless
57+ // of whether the transport added the encoding.
58+ func TransportCustomEval (fn func (header http.Header ) bool ) transportOption {
59+ return func (c * gzRoundtripper ) {
60+ c .customEval = fn
61+ }
62+ }
63+
5464type gzRoundtripper struct {
5565 parent http.RoundTripper
5666 acceptEncoding string
5767 withZstd , withGzip bool
68+ customEval func (header http.Header ) bool
5869}
5970
6071func (g * gzRoundtripper ) RoundTrip (req * http.Request ) (* http.Response , error ) {
@@ -82,16 +93,22 @@ func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
8293 if err != nil || ! requestedComp {
8394 return resp , err
8495 }
85-
96+ decompress := false
97+ if g .customEval != nil {
98+ if ! g .customEval (resp .Header ) {
99+ return resp , nil
100+ }
101+ decompress = true
102+ }
86103 // Decompress
87- if g .withGzip && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "gzip" ) {
104+ if ( decompress || g .withGzip ) && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "gzip" ) {
88105 resp .Body = & gzipReader {body : resp .Body }
89106 resp .Header .Del ("Content-Encoding" )
90107 resp .Header .Del ("Content-Length" )
91108 resp .ContentLength = - 1
92109 resp .Uncompressed = true
93110 }
94- if g .withZstd && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "zstd" ) {
111+ if ( decompress || g .withZstd ) && asciiEqualFold (resp .Header .Get ("Content-Encoding" ), "zstd" ) {
95112 resp .Body = & zstdReader {body : resp .Body }
96113 resp .Header .Del ("Content-Encoding" )
97114 resp .Header .Del ("Content-Length" )
0 commit comments