Skip to content

Commit 901aaf2

Browse files
authored
zstd: Improve block encoding speed (#456)
* zstd: Improve block encoding speed
* Unify loops, avoid check.
1 parent 6f71bfc commit 901aaf2

File tree

2 files changed

+70
-49
lines changed

2 files changed

+70
-49
lines changed

zstd/bitwriter.go

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,34 @@ func (b *bitWriter) addBits16NC(value uint16, bits uint8) {
3838
b.nBits += bits
3939
}
4040

41-
// addBits32NC will add up to 32 bits.
41+
// addBits32NC will add up to 31 bits.
4242
// It will not check if there is space for them,
4343
// so the caller must ensure that it has flushed recently.
4444
func (b *bitWriter) addBits32NC(value uint32, bits uint8) {
4545
b.bitContainer |= uint64(value&bitMask32[bits&31]) << (b.nBits & 63)
4646
b.nBits += bits
4747
}
4848

49+
// addBits64NC will add up to 64 bits.
50+
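// The caller must ensure there is space for 32 bits; when more than 31 bits are added, it flushes between the low 32 bits and the remaining high bits.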
51+
func (b *bitWriter) addBits64NC(value uint64, bits uint8) {
52+
if bits <= 31 {
53+
b.addBits32Clean(uint32(value), bits)
54+
return
55+
}
56+
b.addBits32Clean(uint32(value), 32)
57+
b.flush32()
58+
b.addBits32Clean(uint32(value>>32), bits-32)
59+
}
60+
61+
// addBits32Clean will add up to 32 bits.
62+
// It will not check if there is space for them.
63+
// The input must not have any bits set above the lowest 'bits' bits, since it is added to the container without masking.
64+
func (b *bitWriter) addBits32Clean(value uint32, bits uint8) {
65+
b.bitContainer |= uint64(value) << (b.nBits & 63)
66+
b.nBits += bits
67+
}
68+
4969
// addBits16Clean will add up to 16 bits. value may not contain more set bits than indicated.
5070
// It will not check if there is space for them, so the caller must ensure that it has flushed recently.
5171
func (b *bitWriter) addBits16Clean(value uint16, bits uint8) {

zstd/blockenc.go

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -722,52 +722,53 @@ func (b *blockEnc) encode(org []byte, raw, rawAllLits bool) error {
722722
println("Encoded seq", seq, s, "codes:", s.llCode, s.mlCode, s.ofCode, "states:", ll.state, ml.state, of.state, "bits:", llB, mlB, ofB)
723723
}
724724
seq--
725-
if llEnc.maxBits+mlEnc.maxBits+ofEnc.maxBits <= 32 {
726-
// No need to flush (common)
727-
for seq >= 0 {
728-
s = b.sequences[seq]
729-
wr.flush32()
730-
llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode]
731-
// tabelog max is 8 for all.
732-
of.encode(ofB)
733-
ml.encode(mlB)
734-
ll.encode(llB)
735-
wr.flush32()
736-
737-
// We checked that all can stay within 32 bits
738-
wr.addBits32NC(s.litLen, llB.outBits)
739-
wr.addBits32NC(s.matchLen, mlB.outBits)
740-
wr.addBits32NC(s.offset, ofB.outBits)
741-
742-
if debugSequences {
743-
println("Encoded seq", seq, s)
744-
}
745-
746-
seq--
747-
}
748-
} else {
749-
for seq >= 0 {
750-
s = b.sequences[seq]
751-
wr.flush32()
752-
llB, ofB, mlB := llTT[s.llCode], ofTT[s.ofCode], mlTT[s.mlCode]
753-
// tabelog max is below 8 for each.
754-
of.encode(ofB)
755-
ml.encode(mlB)
756-
ll.encode(llB)
757-
wr.flush32()
758-
759-
// ml+ll = max 32 bits total
760-
wr.addBits32NC(s.litLen, llB.outBits)
761-
wr.addBits32NC(s.matchLen, mlB.outBits)
762-
wr.flush32()
763-
wr.addBits32NC(s.offset, ofB.outBits)
764-
765-
if debugSequences {
766-
println("Encoded seq", seq, s)
767-
}
768-
769-
seq--
770-
}
725+
// Store sequences in reverse...
726+
for seq >= 0 {
727+
s = b.sequences[seq]
728+
729+
ofB := ofTT[s.ofCode]
730+
wr.flush32() // tablelog max is below 8 for each, so it will fill max 24 bits.
731+
//of.encode(ofB)
732+
nbBitsOut := (uint32(of.state) + ofB.deltaNbBits) >> 16
733+
dstState := int32(of.state>>(nbBitsOut&15)) + int32(ofB.deltaFindState)
734+
wr.addBits16NC(of.state, uint8(nbBitsOut))
735+
of.state = of.stateTable[dstState]
736+
737+
// Accumulate extra bits.
738+
outBits := ofB.outBits & 31
739+
extraBits := uint64(s.offset & bitMask32[outBits])
740+
extraBitsN := outBits
741+
742+
mlB := mlTT[s.mlCode]
743+
//ml.encode(mlB)
744+
nbBitsOut = (uint32(ml.state) + mlB.deltaNbBits) >> 16
745+
dstState = int32(ml.state>>(nbBitsOut&15)) + int32(mlB.deltaFindState)
746+
wr.addBits16NC(ml.state, uint8(nbBitsOut))
747+
ml.state = ml.stateTable[dstState]
748+
749+
outBits = mlB.outBits & 31
750+
extraBits = extraBits<<outBits | uint64(s.matchLen&bitMask32[outBits])
751+
extraBitsN += outBits
752+
753+
llB := llTT[s.llCode]
754+
//ll.encode(llB)
755+
nbBitsOut = (uint32(ll.state) + llB.deltaNbBits) >> 16
756+
dstState = int32(ll.state>>(nbBitsOut&15)) + int32(llB.deltaFindState)
757+
wr.addBits16NC(ll.state, uint8(nbBitsOut))
758+
ll.state = ll.stateTable[dstState]
759+
760+
outBits = llB.outBits & 31
761+
extraBits = extraBits<<outBits | uint64(s.litLen&bitMask32[outBits])
762+
extraBitsN += outBits
763+
764+
wr.flush32()
765+
wr.addBits64NC(extraBits, extraBitsN)
766+
767+
if debugSequences {
768+
println("Encoded seq", seq, s)
769+
}
770+
771+
seq--
771772
}
772773
ml.flush(mlEnc.actualTableLog)
773774
of.flush(ofEnc.actualTableLog)
@@ -820,7 +821,8 @@ func (b *blockEnc) genCodes() {
820821
}
821822

822823
var llMax, ofMax, mlMax uint8
823-
for i, seq := range b.sequences {
824+
for i := range b.sequences {
825+
seq := &b.sequences[i]
824826
v := llCode(seq.litLen)
825827
seq.llCode = v
826828
llH[v]++
@@ -844,7 +846,6 @@ func (b *blockEnc) genCodes() {
844846
panic(fmt.Errorf("mlMax > maxMatchLengthSymbol (%d), matchlen: %d", mlMax, seq.matchLen))
845847
}
846848
}
847-
b.sequences[i] = seq
848849
}
849850
maxCount := func(a []uint32) int {
850851
var max uint32

0 commit comments

Comments
 (0)