Skip to content

Commit 69a8ecc

Browse files
authored
zstd: Fix amd64 not always detecting corrupt data (#785)
* zstd: Fix amd64 not always detecting corrupt data Fix undetected corrupt data in amd64 assembly. In rare cases overreads would not get returned as errors, if a multiple of 256 bits was overread. This would make the "bitsread" equal the expected 64. Whenever all bytes has been read from memory we start checking if more than 64 bits has been read on every fill. This ensures that an overflow can never occur. No invalid memory was accessed, this is merely a question if errors are reported. Fixes https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=57290
1 parent 7633d62 commit 69a8ecc

File tree

7 files changed

+160
-23
lines changed

7 files changed

+160
-23
lines changed

internal/fuzz/helpers.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ func AddFromZip(f *testing.F, filename string, t InputType, short bool) {
6262
t = TypeRaw // Fallback
6363
if len(b) >= 4 {
6464
sz := binary.BigEndian.Uint32(b)
65-
if sz == uint32(len(b))-4 {
66-
f.Add(b[4:])
65+
if sz <= uint32(len(b))-4 {
66+
f.Add(b[4 : 4+sz])
6767
continue
6868
}
6969
}

zstd/_generate/gen.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ const errorNotEnoughLiterals = 4
3434
// error reported when capacity of `out` is too small
3535
const errorNotEnoughSpace = 5
3636

37+
// error reported when bits are overread.
38+
const errorOverread = 6
39+
3740
const maxMatchLen = 131074
3841

3942
// size of struct seqVals
@@ -247,8 +250,9 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
247250
{
248251
brPointer := GP64()
249252
MOVQ(brPointerStash, brPointer)
253+
250254
Comment("Fill bitreader to have enough for the offset and match length.")
251-
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer)
255+
o.bitreaderFill(name+"_fill", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))
252256

253257
Comment("Update offset")
254258
// Up to 32 extra bits
@@ -261,7 +265,7 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
261265
// If we need more than 56 in total, we must refill here.
262266
if !o.fiftysix {
263267
Comment("Fill bitreader to have enough for the remaining")
264-
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer)
268+
o.bitreaderFill(name+"_fill_2", brValue, brBitsRead, brOffset, brPointer, LabelRef("error_overread"))
265269
}
266270

267271
Comment("Update literal length")
@@ -502,6 +506,12 @@ func (o options) generateBody(name string, executeSingleTriple func(ctx *execute
502506
o.returnWithCode(errorNotEnoughLiterals)
503507
}
504508

509+
Comment("Return with overread error")
510+
{
511+
Label("error_overread")
512+
o.returnWithCode(errorOverread)
513+
}
514+
505515
if !o.useSeqs {
506516
Comment("Return with not enough output space error")
507517
Label("error_not_enough_space")
@@ -529,7 +539,7 @@ func (o options) returnWithCode(returnCode uint32) {
529539
}
530540

531541
// bitreaderFill will make sure at least 56 bits are available.
532-
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual) {
542+
func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPointer reg.GPVirtual, overread LabelRef) {
533543
// bitreader_fill begin
534544
CMPQ(brOffset, U8(8)) // b.off >= 8
535545
JL(LabelRef(name + "_byte_by_byte"))
@@ -545,7 +555,7 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi
545555

546556
Label(name + "_byte_by_byte")
547557
CMPQ(brOffset, U8(0)) /* for b.off > 0 */
548-
JLE(LabelRef(name + "_end"))
558+
JLE(LabelRef(name + "_check_overread"))
549559

550560
CMPQ(brBitsRead, U8(7)) /* for brBitsRead > 7 */
551561
JLE(LabelRef(name + "_end"))
@@ -565,6 +575,10 @@ func (o options) bitreaderFill(name string, brValue, brBitsRead, brOffset, brPoi
565575
}
566576
JMP(LabelRef(name + "_byte_by_byte"))
567577

578+
Label(name + "_check_overread")
579+
CMPQ(brBitsRead, U8(64))
580+
JA(overread)
581+
568582
Label(name + "_end")
569583
}
570584

zstd/fuzz_test.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ func FuzzDecAllNoBMI2(f *testing.F) {
6666
func FuzzDecoder(f *testing.F) {
6767
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-raw.zip", fuzz.TypeRaw, testing.Short())
6868
fuzz.AddFromZip(f, "testdata/fuzz/decode-corpus-encoded.zip", fuzz.TypeGoFuzz, testing.Short())
69+
//fuzz.AddFromZip(f, "testdata/fuzz/decode-oss.zip", fuzz.TypeOSSFuzz, false)
6970

7071
brLow := newBytesReader(nil)
7172
brHi := newBytesReader(nil)
@@ -92,18 +93,25 @@ func FuzzDecoder(f *testing.F) {
9293
}
9394
defer decHi.Close()
9495

96+
if debugDecoder {
97+
fmt.Println("LOW CONCURRENT")
98+
}
9599
b1, err1 := io.ReadAll(decLow)
100+
101+
if debugDecoder {
102+
fmt.Println("HI NOT CONCURRENT")
103+
}
96104
b2, err2 := io.ReadAll(decHi)
97105
if err1 != err2 {
98106
if (err1 == nil) != (err2 == nil) {
99-
t.Errorf("err low: %v, hi: %v", err1, err2)
107+
t.Errorf("err low concurrent: %v, hi: %v", err1, err2)
100108
}
101109
}
102110
if err1 != nil {
103111
b1, b2 = b1[:0], b2[:0]
104112
}
105113
if !bytes.Equal(b1, b2) {
106-
t.Fatalf("Output mismatch, low: %v, hi: %v", err1, err2)
114+
t.Fatalf("Output mismatch, low concurrent: %v, hi: %v", err1, err2)
107115
}
108116
})
109117
}

zstd/seqdec.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,9 +236,12 @@ func (s *sequenceDecs) decodeSync(hist []byte) error {
236236
maxBlockSize = s.windowSize
237237
}
238238

239+
if debugDecoder {
240+
println("decodeSync: decoding", seqs, "sequences", br.remain(), "bits remain on stream")
241+
}
239242
for i := seqs - 1; i >= 0; i-- {
240243
if br.overread() {
241-
printf("reading sequence %d, exceeded available data\n", seqs-i)
244+
printf("reading sequence %d, exceeded available data. Overread by %d\n", seqs-i, -br.remain())
242245
return io.ErrUnexpectedEOF
243246
}
244247
var ll, mo, ml int

zstd/seqdec_amd64.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package zstd
55

66
import (
77
"fmt"
8+
"io"
89

910
"github.com/klauspost/compress/internal/cpuinfo"
1011
)
@@ -134,6 +135,9 @@ func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
134135
return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
135136
ctx.ll, ctx.litRemain+ctx.ll)
136137

138+
case errorOverread:
139+
return true, io.ErrUnexpectedEOF
140+
137141
case errorNotEnoughSpace:
138142
size := ctx.outPosition + ctx.ll + ctx.ml
139143
if debugDecoder {
@@ -202,6 +206,9 @@ const errorNotEnoughLiterals = 4
202206
// error reported when capacity of `out` is too small
203207
const errorNotEnoughSpace = 5
204208

209+
// error reported when bits are overread.
210+
const errorOverread = 6
211+
205212
// sequenceDecs_decode implements the main loop of sequenceDecs in x86 asm.
206213
//
207214
// Please refer to seqdec_generic.go for the reference implementation.
@@ -247,6 +254,10 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
247254
litRemain: len(s.literals),
248255
}
249256

257+
if debugDecoder {
258+
println("decode: decoding", len(seqs), "sequences", br.remain(), "bits remain on stream")
259+
}
260+
250261
s.seqSize = 0
251262
lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
252263
var errCode int
@@ -277,6 +288,8 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
277288
case errorNotEnoughLiterals:
278289
ll := ctx.seqs[i].ll
279290
return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
291+
case errorOverread:
292+
return io.ErrUnexpectedEOF
280293
}
281294

282295
return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
@@ -291,6 +304,9 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
291304
if s.seqSize > maxBlockSize {
292305
return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
293306
}
307+
if debugDecoder {
308+
println("decode: ", br.remain(), "bits remain on stream. code:", errCode)
309+
}
294310
err := br.close()
295311
if err != nil {
296312
printf("Closing sequences: %v, %+v\n", err, *br)

0 commit comments

Comments
 (0)