Skip to content

Commit d172db7

Browse files
authored
zstd: Detect short invalid signatures (#382)
Detect short frame signatures. Fixes #381
1 parent e95c300 commit d172db7

File tree

2 files changed

+57
-11
lines changed

2 files changed

+57
-11
lines changed

zstd/decoder_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,31 @@ func TestDecoderRegression(t *testing.T) {
575575
}
576576
}
577577

578+
func TestShort(t *testing.T) {
579+
for _, in := range []string{"f", "fo", "foo"} {
580+
inb := []byte(in)
581+
dec, err := NewReader(nil)
582+
if err != nil {
583+
t.Fatal(err)
584+
}
585+
defer dec.Close()
586+
587+
t.Run(fmt.Sprintf("DecodeAll-%d", len(in)), func(t *testing.T) {
588+
_, err := dec.DecodeAll(inb, nil)
589+
if err == nil {
590+
t.Error("want error, got nil")
591+
}
592+
})
593+
t.Run(fmt.Sprintf("Reader-%d", len(in)), func(t *testing.T) {
594+
dec.Reset(bytes.NewReader(inb))
595+
_, err := io.Copy(ioutil.Discard, dec)
596+
if err == nil {
597+
t.Error("want error, got nil")
598+
}
599+
})
600+
}
601+
}
602+
578603
func TestDecoder_Reset(t *testing.T) {
579604
in, err := ioutil.ReadFile("testdata/z000028")
580605
if err != nil {

zstd/framedec.go

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,28 +78,43 @@ func newFrameDec(o decoderOptions) *frameDec {
7878
func (d *frameDec) reset(br byteBuffer) error {
7979
d.HasCheckSum = false
8080
d.WindowSize = 0
81-
var b []byte
81+
var signature [4]byte
8282
for {
8383
var err error
84-
b, err = br.readSmall(4)
84+
// Check if we can read more...
85+
b, err := br.readSmall(1)
8586
switch err {
8687
case io.EOF, io.ErrUnexpectedEOF:
8788
return io.EOF
8889
default:
8990
return err
9091
case nil:
92+
signature[0] = b[0]
93+
}
94+
// Read the rest, don't allow io.ErrUnexpectedEOF
95+
b, err = br.readSmall(3)
96+
switch err {
97+
case io.EOF:
98+
return io.EOF
99+
default:
100+
return err
101+
case nil:
102+
copy(signature[1:], b)
91103
}
92-
if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
104+
105+
if !bytes.Equal(signature[1:4], skippableFrameMagic) || signature[0]&0xf0 != 0x50 {
93106
if debugDecoder {
94-
println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
107+
println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString(skippableFrameMagic))
95108
}
96109
// Break if not skippable frame.
97110
break
98111
}
99112
// Read size to skip
100113
b, err = br.readSmall(4)
101114
if err != nil {
102-
println("Reading Frame Size", err)
115+
if debugDecoder {
116+
println("Reading Frame Size", err)
117+
}
103118
return err
104119
}
105120
n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
@@ -112,15 +127,19 @@ func (d *frameDec) reset(br byteBuffer) error {
112127
return err
113128
}
114129
}
115-
if !bytes.Equal(b, frameMagic) {
116-
println("Got magic numbers: ", b, "want:", frameMagic)
130+
if !bytes.Equal(signature[:], frameMagic) {
131+
if debugDecoder {
132+
println("Got magic numbers: ", signature, "want:", frameMagic)
133+
}
117134
return ErrMagicMismatch
118135
}
119136

120137
// Read Frame_Header_Descriptor
121138
fhd, err := br.readByte()
122139
if err != nil {
123-
println("Reading Frame_Header_Descriptor", err)
140+
if debugDecoder {
141+
println("Reading Frame_Header_Descriptor", err)
142+
}
124143
return err
125144
}
126145
d.SingleSegment = fhd&(1<<5) != 0
@@ -135,7 +154,9 @@ func (d *frameDec) reset(br byteBuffer) error {
135154
if !d.SingleSegment {
136155
wd, err := br.readByte()
137156
if err != nil {
138-
println("Reading Window_Descriptor", err)
157+
if debugDecoder {
158+
println("Reading Window_Descriptor", err)
159+
}
139160
return err
140161
}
141162
printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
@@ -153,7 +174,7 @@ func (d *frameDec) reset(br byteBuffer) error {
153174
size = 4
154175
}
155176

156-
b, err = br.readSmall(int(size))
177+
b, err := br.readSmall(int(size))
157178
if err != nil {
158179
println("Reading Dictionary_ID", err)
159180
return err
@@ -191,7 +212,7 @@ func (d *frameDec) reset(br byteBuffer) error {
191212
}
192213
d.FrameContentSize = 0
193214
if fcsSize > 0 {
194-
b, err = br.readSmall(fcsSize)
215+
b, err := br.readSmall(fcsSize)
195216
if err != nil {
196217
println("Reading Frame content", err)
197218
return err

0 commit comments

Comments
 (0)