Skip to content

Speed up huff0 table decode #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 84 additions & 35 deletions huff0/decompress.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ type dTable struct {

// single-symbols decoding
type dEntrySingle struct {
byte uint8
nBits uint8
entry uint16
}

// double-symbols decoding
Expand Down Expand Up @@ -76,14 +75,15 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
}

// collect weight stats
var rankStats [tableLogMax + 1]uint32
var rankStats [16]uint32
weightTotal := uint32(0)
for _, v := range s.huffWeight[:s.symbolLen] {
if v > tableLogMax {
return s, nil, errors.New("corrupt input: weight too large")
}
rankStats[v]++
weightTotal += (1 << (v & 15)) >> 1
v2 := v & 15
rankStats[v2]++
weightTotal += (1 << v2) >> 1
}
if weightTotal == 0 {
return s, nil, errors.New("corrupt input: weights zero")
Expand Down Expand Up @@ -134,15 +134,17 @@ func ReadTable(in []byte, s *Scratch) (s2 *Scratch, remain []byte, err error) {
if len(s.dt.single) != tSize {
s.dt.single = make([]dEntrySingle, tSize)
}

for n, w := range s.huffWeight[:s.symbolLen] {
if w == 0 {
continue
}
length := (uint32(1) << w) >> 1
d := dEntrySingle{
byte: uint8(n),
nBits: s.actualTableLog + 1 - w,
entry: uint16(s.actualTableLog+1-w) | (uint16(n) << 8),
}
for u := rankStats[w]; u < rankStats[w]+length; u++ {
s.dt.single[u] = d
single := s.dt.single[rankStats[w] : rankStats[w]+length]
for i := range single {
single[i] = d
}
rankStats[w] += length
}
Expand All @@ -167,12 +169,12 @@ func (s *Scratch) Decompress1X(in []byte) (out []byte, err error) {
decode := func() byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := s.dt.single[val]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}
hasDec := func(v dEntrySingle) byte {
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Avoid bounds check by always having full sized table.
Expand Down Expand Up @@ -269,8 +271,8 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
decode := func(br *bitReader) byte {
val := br.peekBitsFast(s.actualTableLog) /* note : actualTableLog >= 1 */
v := single[val&tlMask]
br.bitsRead += v.nBits
return v.byte
br.bitsRead += uint8(v.entry)
return uint8(v.entry >> 8)
}

// Use temp table to avoid bound checks/append penalty.
Expand All @@ -283,20 +285,67 @@ func (s *Scratch) Decompress4X(in []byte, dstSize int) (out []byte, err error) {
bigloop:
for {
for i := range br {
if br[i].off < 4 {
br := &br[i]
if br.off < 4 {
break bigloop
}
br[i].fillFast()
br.fillFast()
}

{
const stream = 0
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 1
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

{
const stream = 2
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}
tmp[off] = decode(&br[0])
tmp[off+bufoff] = decode(&br[1])
tmp[off+bufoff*2] = decode(&br[2])
tmp[off+bufoff*3] = decode(&br[3])
tmp[off+1] = decode(&br[0])
tmp[off+1+bufoff] = decode(&br[1])
tmp[off+1+bufoff*2] = decode(&br[2])
tmp[off+1+bufoff*3] = decode(&br[3])

{
const stream = 3
val := br[stream].peekBitsFast(s.actualTableLog)
v := single[val&tlMask]
br[stream].bitsRead += uint8(v.entry)

val2 := br[stream].peekBitsFast(s.actualTableLog)
v2 := single[val2&tlMask]
tmp[off+bufoff*stream+1] = uint8(v2.entry >> 8)
tmp[off+bufoff*stream] = uint8(v.entry >> 8)
br[stream].bitsRead += uint8(v2.entry)
}

off += 2

if off == bufoff {
if bufoff > dstEvery {
return nil, errors.New("corruption detected: stream overrun 1")
Expand Down Expand Up @@ -367,7 +416,7 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
broken++
if enc.nBits == 0 {
for _, dec := range dt {
if dec.byte == byte(sym) {
if uint8(dec.entry>>8) == byte(sym) {
fmt.Fprintf(w, "symbol %x has decoder, but no encoder\n", sym)
errs++
break
Expand All @@ -383,12 +432,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
top := enc.val << ub
// decoder looks at top bits.
dec := dt[top]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", sym, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", sym, sym, uint8(dec.entry>>8))
errs++
}
if errs > 0 {
Expand All @@ -399,12 +448,12 @@ func (s *Scratch) matches(ct cTable, w io.Writer) {
for i := uint16(0); i < (1 << ub); i++ {
vval := top | i
dec := dt[vval]
if dec.nBits != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, dec.nBits)
if uint8(dec.entry) != enc.nBits {
fmt.Fprintf(w, "symbol 0x%x bit size mismatch (enc: %d, dec:%d).\n", vval, enc.nBits, uint8(dec.entry))
errs++
}
if dec.byte != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, dec.byte)
if uint8(dec.entry>>8) != uint8(sym) {
fmt.Fprintf(w, "symbol 0x%x decoder output mismatch (enc: %d, dec:%d).\n", vval, sym, uint8(dec.entry>>8))
errs++
}
if errs > 20 {
Expand Down
38 changes: 38 additions & 0 deletions huff0/decompress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,3 +478,41 @@ func BenchmarkDecompress4XNoTable(b *testing.B) {
})
}
}

func BenchmarkDecompress4XTable(b *testing.B) {
for _, tt := range testfiles {
test := tt
if test.err4X != nil {
continue
}
b.Run(test.name, func(b *testing.B) {
var s = &Scratch{}
s.Reuse = ReusePolicyNone
buf0, err := test.fn()
if err != nil {
b.Fatal(err)
}
if len(buf0) > BlockSizeMax {
buf0 = buf0[:BlockSizeMax]
}
compressed, _, err := Compress4X(buf0, s)
if err != test.err1X {
b.Fatal("unexpected error:", err)
}
s.Out = nil
b.ResetTimer()
b.ReportAllocs()
b.SetBytes(int64(len(buf0)))
for i := 0; i < b.N; i++ {
s, remain, err := ReadTable(compressed, s)
if err != nil {
b.Fatal(err)
}
_, err = s.Decompress4X(remain, len(buf0))
if err != nil {
b.Fatal(err)
}
}
})
}
}
Binary file modified zstd/testdata/benchdecoder.zip
Binary file not shown.