Skip to content

Commit 98ff542

Browse files
authored
gzhttp: Allow overriding decompression on transport (#892)
This allows getting compressed data even if `Content-Encoding` is set. Also allows decompression even if "Accept-Encoding" was not set by this client.
1 parent c63a492 commit 98ff542

File tree

2 files changed

+66
-4
lines changed

2 files changed

+66
-4
lines changed

gzhttp/transport.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import (
1414
"github.com/klauspost/compress/zstd"
1515
)
1616

17-
// Transport will wrap a transport with a custom handler
17+
// Transport will wrap an HTTP transport with a custom handler
1818
// that will request gzip and automatically decompress it.
1919
// Using this is significantly faster than using the default transport.
2020
func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
@@ -51,10 +51,21 @@ func TransportEnableGzip(b bool) transportOption {
5151
}
5252
}
5353

54+
// TransportCustomEval will send the header of a response to a custom function.
55+
// If the function returns false, the response will be returned as-is,
56+
// Otherwise it will be decompressed based on Content-Encoding field, regardless
57+
// of whether the transport added the encoding.
58+
func TransportCustomEval(fn func(header http.Header) bool) transportOption {
59+
return func(c *gzRoundtripper) {
60+
c.customEval = fn
61+
}
62+
}
63+
5464
type gzRoundtripper struct {
5565
parent http.RoundTripper
5666
acceptEncoding string
5767
withZstd, withGzip bool
68+
customEval func(header http.Header) bool
5869
}
5970

6071
func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
@@ -82,16 +93,22 @@ func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
8293
if err != nil || !requestedComp {
8394
return resp, err
8495
}
85-
96+
decompress := false
97+
if g.customEval != nil {
98+
if !g.customEval(resp.Header) {
99+
return resp, nil
100+
}
101+
decompress = true
102+
}
86103
// Decompress
87-
if g.withGzip && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
104+
if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
88105
resp.Body = &gzipReader{body: resp.Body}
89106
resp.Header.Del("Content-Encoding")
90107
resp.Header.Del("Content-Length")
91108
resp.ContentLength = -1
92109
resp.Uncompressed = true
93110
}
94-
if g.withZstd && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
111+
if (decompress || g.withZstd) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
95112
resp.Body = &zstdReader{body: resp.Body}
96113
resp.Header.Del("Content-Encoding")
97114
resp.Header.Del("Content-Length")

gzhttp/transport_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,51 @@ func TestDefaultTransport(t *testing.T) {
206206
}
207207
}
208208

209+
func TestTransportCustomEval(t *testing.T) {
210+
bin, err := os.ReadFile("testdata/benchmark.json")
211+
if err != nil {
212+
t.Fatal(err)
213+
}
214+
215+
server := httptest.NewServer(newTestHandler(bin))
216+
calledWith := ""
217+
c := http.Client{Transport: Transport(http.DefaultTransport, TransportEnableZstd(false), TransportCustomEval(func(h http.Header) bool {
218+
calledWith = h.Get("Content-Encoding")
219+
return true
220+
}))}
221+
resp, err := c.Get(server.URL)
222+
if err != nil {
223+
t.Fatal(err)
224+
}
225+
got, err := io.ReadAll(resp.Body)
226+
if err != nil {
227+
t.Fatal(err)
228+
}
229+
if !bytes.Equal(got, bin) {
230+
t.Errorf("data mismatch")
231+
}
232+
if calledWith != "gzip" {
233+
t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
234+
}
235+
// Test returning false
236+
c = http.Client{Transport: Transport(http.DefaultTransport, TransportCustomEval(func(h http.Header) bool {
237+
calledWith = h.Get("Content-Encoding")
238+
return false
239+
}))}
240+
resp, err = c.Get(server.URL)
241+
if err != nil {
242+
t.Fatal(err)
243+
}
244+
// Check we got the compressed data
245+
gotCE := resp.Header.Get("Content-Encoding")
246+
if gotCE != "gzip" {
247+
t.Fatalf("Expected encoding %q, got %q", "gzip", gotCE)
248+
}
249+
if calledWith != "gzip" {
250+
t.Fatalf("Expected encoding %q, got %q", "gzip", calledWith)
251+
}
252+
}
253+
209254
func BenchmarkTransport(b *testing.B) {
210255
raw, err := os.ReadFile("testdata/benchmark.json")
211256
if err != nil {

0 commit comments

Comments
 (0)