Skip to content

Commit 945c212

Browse files
committed
Backport fixes from pgx v5
Check for overflow on uint16 sizes in pgproto3 Do not allow protocol messages larger than ~1GB The PostgreSQL server will reject messages greater than ~1 GB anyway. However, worse than that is that a message that is larger than 4 GB could wrap the 32-bit integer message size and be interpreted by the server as multiple messages. This could allow a malicious client to inject arbitrary protocol messages. GHSA-mrww-27vc-gghv
1 parent 0c0f7b0 commit 945c212

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+359
-359
lines changed

authentication_cleartext_password.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ func (dst *AuthenticationCleartextPassword) Decode(src []byte) error {
3535
}
3636

3737
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
38-
func (src *AuthenticationCleartextPassword) Encode(dst []byte) []byte {
39-
dst = append(dst, 'R')
40-
dst = pgio.AppendInt32(dst, 8)
38+
func (src *AuthenticationCleartextPassword) Encode(dst []byte) ([]byte, error) {
39+
dst, sp := beginMessage(dst, 'R')
4140
dst = pgio.AppendUint32(dst, AuthTypeCleartextPassword)
42-
return dst
41+
return finishMessage(dst, sp)
4342
}
4443

4544
// MarshalJSON implements encoding/json.Marshaler.

authentication_gss.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/binary"
55
"encoding/json"
66
"errors"
7+
78
"github.com/jackc/pgio"
89
)
910

@@ -26,11 +27,10 @@ func (a *AuthenticationGSS) Decode(src []byte) error {
2627
return nil
2728
}
2829

29-
func (a *AuthenticationGSS) Encode(dst []byte) []byte {
30-
dst = append(dst, 'R')
31-
dst = pgio.AppendInt32(dst, 4)
30+
func (a *AuthenticationGSS) Encode(dst []byte) ([]byte, error) {
31+
dst, sp := beginMessage(dst, 'R')
3232
dst = pgio.AppendUint32(dst, AuthTypeGSS)
33-
return dst
33+
return finishMessage(dst, sp)
3434
}
3535

3636
func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {

authentication_gss_continue.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"encoding/binary"
55
"encoding/json"
66
"errors"
7+
78
"github.com/jackc/pgio"
89
)
910

@@ -30,12 +31,11 @@ func (a *AuthenticationGSSContinue) Decode(src []byte) error {
3031
return nil
3132
}
3233

33-
func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
34-
dst = append(dst, 'R')
35-
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
34+
func (a *AuthenticationGSSContinue) Encode(dst []byte) ([]byte, error) {
35+
dst, sp := beginMessage(dst, 'R')
3636
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
3737
dst = append(dst, a.Data...)
38-
return dst
38+
return finishMessage(dst, sp)
3939
}
4040

4141
func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {

authentication_md5_password.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@ func (dst *AuthenticationMD5Password) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationMD5Password) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
dst = pgio.AppendInt32(dst, 12)
41+
func (src *AuthenticationMD5Password) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4443
dst = pgio.AppendUint32(dst, AuthTypeMD5Password)
4544
dst = append(dst, src.Salt[:]...)
46-
return dst
45+
return finishMessage(dst, sp)
4746
}
4847

4948
// MarshalJSON implements encoding/json.Marshaler.

authentication_ok.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,10 @@ func (dst *AuthenticationOk) Decode(src []byte) error {
3535
}
3636

3737
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
38-
func (src *AuthenticationOk) Encode(dst []byte) []byte {
39-
dst = append(dst, 'R')
40-
dst = pgio.AppendInt32(dst, 8)
38+
func (src *AuthenticationOk) Encode(dst []byte) ([]byte, error) {
39+
dst, sp := beginMessage(dst, 'R')
4140
dst = pgio.AppendUint32(dst, AuthTypeOk)
42-
return dst
41+
return finishMessage(dst, sp)
4342
}
4443

4544
// MarshalJSON implements encoding/json.Marshaler.

authentication_sasl.go

+3-7
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
4646
}
4747

4848
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
49-
func (src *AuthenticationSASL) Encode(dst []byte) []byte {
50-
dst = append(dst, 'R')
51-
sp := len(dst)
52-
dst = pgio.AppendInt32(dst, -1)
49+
func (src *AuthenticationSASL) Encode(dst []byte) ([]byte, error) {
50+
dst, sp := beginMessage(dst, 'R')
5351
dst = pgio.AppendUint32(dst, AuthTypeSASL)
5452

5553
for _, s := range src.AuthMechanisms {
@@ -58,9 +56,7 @@ func (src *AuthenticationSASL) Encode(dst []byte) []byte {
5856
}
5957
dst = append(dst, 0)
6058

61-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
62-
63-
return dst
59+
return finishMessage(dst, sp)
6460
}
6561

6662
// MarshalJSON implements encoding/json.Marshaler.

authentication_sasl_continue.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLContinue) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationSASLContinue) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
sp := len(dst)
44-
dst = pgio.AppendInt32(dst, -1)
41+
func (src *AuthenticationSASLContinue) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4543
dst = pgio.AppendUint32(dst, AuthTypeSASLContinue)
46-
4744
dst = append(dst, src.Data...)
48-
49-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
50-
51-
return dst
45+
return finishMessage(dst, sp)
5246
}
5347

5448
// MarshalJSON implements encoding/json.Marshaler.

authentication_sasl_final.go

+3-9
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,11 @@ func (dst *AuthenticationSASLFinal) Decode(src []byte) error {
3838
}
3939

4040
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
41-
func (src *AuthenticationSASLFinal) Encode(dst []byte) []byte {
42-
dst = append(dst, 'R')
43-
sp := len(dst)
44-
dst = pgio.AppendInt32(dst, -1)
41+
func (src *AuthenticationSASLFinal) Encode(dst []byte) ([]byte, error) {
42+
dst, sp := beginMessage(dst, 'R')
4543
dst = pgio.AppendUint32(dst, AuthTypeSASLFinal)
46-
4744
dst = append(dst, src.Data...)
48-
49-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
50-
51-
return dst
45+
return finishMessage(dst, sp)
5246
}
5347

5448
// MarshalJSON implements encoding/json.Unmarshaler.

backend.go

+10-5
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ func NewBackend(cr ChunkReader, w io.Writer) *Backend {
4949

5050
// Send sends a message to the frontend.
5151
func (b *Backend) Send(msg BackendMessage) error {
52-
_, err := b.w.Write(msg.Encode(nil))
52+
buf, err := msg.Encode(nil)
53+
if err != nil {
54+
return err
55+
}
56+
57+
_, err = b.w.Write(buf)
5358
return err
5459
}
5560

@@ -184,11 +189,11 @@ func (b *Backend) Receive() (FrontendMessage, error) {
184189
// contextual identification of FrontendMessages. For example, in the
185190
// PG message flow documentation for PasswordMessage:
186191
//
187-
// Byte1('p')
192+
// Byte1('p')
188193
//
189-
// Identifies the message as a password response. Note that this is also used for
190-
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
191-
// the context.
194+
// Identifies the message as a password response. Note that this is also used for
195+
// GSSAPI, SSPI and SASL response messages. The exact message type can be deduced from
196+
// the context.
192197
//
193198
// Since the Frontend does not know about the state of a backend, it is important
194199
// to call SetAuthType() after an authentication request is received by the Frontend.

backend_key_data.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,11 @@ func (dst *BackendKeyData) Decode(src []byte) error {
2929
}
3030

3131
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
32-
func (src *BackendKeyData) Encode(dst []byte) []byte {
33-
dst = append(dst, 'K')
34-
dst = pgio.AppendUint32(dst, 12)
32+
func (src *BackendKeyData) Encode(dst []byte) ([]byte, error) {
33+
dst, sp := beginMessage(dst, 'K')
3534
dst = pgio.AppendUint32(dst, src.ProcessID)
3635
dst = pgio.AppendUint32(dst, src.SecretKey)
37-
return dst
36+
return finishMessage(dst, sp)
3837
}
3938

4039
// MarshalJSON implements encoding/json.Marshaler.

backend_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ func TestStartupMessage(t *testing.T) {
7171
"username": "tester",
7272
},
7373
}
74-
dst := []byte{}
75-
dst = want.Encode(dst)
74+
dst, err := want.Encode([]byte{})
75+
require.NoError(t, err)
7676

7777
server := &interruptReader{}
7878
server.push(dst)

bind.go

+14-7
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ import (
55
"encoding/binary"
66
"encoding/hex"
77
"encoding/json"
8+
"errors"
89
"fmt"
10+
"math"
911

1012
"github.com/jackc/pgio"
1113
)
@@ -108,21 +110,25 @@ func (dst *Bind) Decode(src []byte) error {
108110
}
109111

110112
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
111-
func (src *Bind) Encode(dst []byte) []byte {
112-
dst = append(dst, 'B')
113-
sp := len(dst)
114-
dst = pgio.AppendInt32(dst, -1)
113+
func (src *Bind) Encode(dst []byte) ([]byte, error) {
114+
dst, sp := beginMessage(dst, 'B')
115115

116116
dst = append(dst, src.DestinationPortal...)
117117
dst = append(dst, 0)
118118
dst = append(dst, src.PreparedStatement...)
119119
dst = append(dst, 0)
120120

121+
if len(src.ParameterFormatCodes) > math.MaxUint16 {
122+
return nil, errors.New("too many parameter format codes")
123+
}
121124
dst = pgio.AppendUint16(dst, uint16(len(src.ParameterFormatCodes)))
122125
for _, fc := range src.ParameterFormatCodes {
123126
dst = pgio.AppendInt16(dst, fc)
124127
}
125128

129+
if len(src.Parameters) > math.MaxUint16 {
130+
return nil, errors.New("too many parameters")
131+
}
126132
dst = pgio.AppendUint16(dst, uint16(len(src.Parameters)))
127133
for _, p := range src.Parameters {
128134
if p == nil {
@@ -134,14 +140,15 @@ func (src *Bind) Encode(dst []byte) []byte {
134140
dst = append(dst, p...)
135141
}
136142

143+
if len(src.ResultFormatCodes) > math.MaxUint16 {
144+
return nil, errors.New("too many result format codes")
145+
}
137146
dst = pgio.AppendUint16(dst, uint16(len(src.ResultFormatCodes)))
138147
for _, fc := range src.ResultFormatCodes {
139148
dst = pgio.AppendInt16(dst, fc)
140149
}
141150

142-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
143-
144-
return dst
151+
return finishMessage(dst, sp)
145152
}
146153

147154
// MarshalJSON implements encoding/json.Marshaler.

bind_complete.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ func (dst *BindComplete) Decode(src []byte) error {
2020
}
2121

2222
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
23-
func (src *BindComplete) Encode(dst []byte) []byte {
24-
return append(dst, '2', 0, 0, 0, 4)
23+
func (src *BindComplete) Encode(dst []byte) ([]byte, error) {
24+
return append(dst, '2', 0, 0, 0, 4), nil
2525
}
2626

2727
// MarshalJSON implements encoding/json.Marshaler.

bind_test.go

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package pgproto3_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/jackc/pgproto3/v2"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestBindBiggerThanMaxMessageBodyLen(t *testing.T) {
11+
t.Parallel()
12+
13+
// Maximum allowed size.
14+
_, err := (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-16)}}).Encode(nil)
15+
require.NoError(t, err)
16+
17+
// 1 byte too big
18+
_, err = (&pgproto3.Bind{Parameters: [][]byte{make([]byte, pgproto3.MaxMessageBodyLen-15)}}).Encode(nil)
19+
require.Error(t, err)
20+
}

cancel_request.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ func (dst *CancelRequest) Decode(src []byte) error {
3636
}
3737

3838
// Encode encodes src into dst. dst will include the 4 byte message length.
39-
func (src *CancelRequest) Encode(dst []byte) []byte {
39+
func (src *CancelRequest) Encode(dst []byte) ([]byte, error) {
4040
dst = pgio.AppendInt32(dst, 16)
4141
dst = pgio.AppendInt32(dst, cancelRequestCode)
4242
dst = pgio.AppendUint32(dst, src.ProcessID)
4343
dst = pgio.AppendUint32(dst, src.SecretKey)
44-
return dst
44+
return dst, nil
4545
}
4646

4747
// MarshalJSON implements encoding/json.Marshaler.

close.go

+3-11
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ import (
44
"bytes"
55
"encoding/json"
66
"errors"
7-
8-
"github.com/jackc/pgio"
97
)
108

119
type Close struct {
@@ -37,18 +35,12 @@ func (dst *Close) Decode(src []byte) error {
3735
}
3836

3937
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
40-
func (src *Close) Encode(dst []byte) []byte {
41-
dst = append(dst, 'C')
42-
sp := len(dst)
43-
dst = pgio.AppendInt32(dst, -1)
44-
38+
func (src *Close) Encode(dst []byte) ([]byte, error) {
39+
dst, sp := beginMessage(dst, 'C')
4540
dst = append(dst, src.ObjectType)
4641
dst = append(dst, src.Name...)
4742
dst = append(dst, 0)
48-
49-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
50-
51-
return dst
43+
return finishMessage(dst, sp)
5244
}
5345

5446
// MarshalJSON implements encoding/json.Marshaler.

close_complete.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ func (dst *CloseComplete) Decode(src []byte) error {
2020
}
2121

2222
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
23-
func (src *CloseComplete) Encode(dst []byte) []byte {
24-
return append(dst, '3', 0, 0, 0, 4)
23+
func (src *CloseComplete) Encode(dst []byte) ([]byte, error) {
24+
return append(dst, '3', 0, 0, 0, 4), nil
2525
}
2626

2727
// MarshalJSON implements encoding/json.Marshaler.

command_complete.go

+3-11
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package pgproto3
33
import (
44
"bytes"
55
"encoding/json"
6-
7-
"github.com/jackc/pgio"
86
)
97

108
type CommandComplete struct {
@@ -28,17 +26,11 @@ func (dst *CommandComplete) Decode(src []byte) error {
2826
}
2927

3028
// Encode encodes src into dst. dst will include the 1 byte message type identifier and the 4 byte message length.
31-
func (src *CommandComplete) Encode(dst []byte) []byte {
32-
dst = append(dst, 'C')
33-
sp := len(dst)
34-
dst = pgio.AppendInt32(dst, -1)
35-
29+
func (src *CommandComplete) Encode(dst []byte) ([]byte, error) {
30+
dst, sp := beginMessage(dst, 'C')
3631
dst = append(dst, src.CommandTag...)
3732
dst = append(dst, 0)
38-
39-
pgio.SetInt32(dst[sp:], int32(len(dst[sp:])))
40-
41-
return dst
33+
return finishMessage(dst, sp)
4234
}
4335

4436
// MarshalJSON implements encoding/json.Marshaler.

0 commit comments

Comments
 (0)