Better fuzz testing and fix several bugs it found

Fix infinite loop in AuthenticationSASL.Decode
Fix panic in CommandComplete.Decode
Fix panic in DataRow.Decode
Fix panic in NotificationResponse.Decode
pull/1281/head
Jack Christensen 2022-07-23 16:13:06 -05:00
parent 9d0f27bc4b
commit 7f382f5190
11 changed files with 67 additions and 16 deletions

View File

@ -36,11 +36,12 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
authMechanisms := src[4:] authMechanisms := src[4:]
for len(authMechanisms) > 1 { for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0) idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 { if idx == -1 {
return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
}
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:] authMechanisms = authMechanisms[idx+1:]
} }
}
return nil return nil
} }

View File

@ -18,8 +18,11 @@ func (*CommandComplete) Backend() {}
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *CommandComplete) Decode(src []byte) error { func (dst *CommandComplete) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0) idx := bytes.IndexByte(src, 0)
if idx == -1 {
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
}
if idx != len(src)-1 { if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CommandComplete"} return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
} }
dst.CommandTag = src[:idx] dst.CommandTag = src[:idx]

View File

@ -43,19 +43,19 @@ func (dst *DataRow) Decode(src []byte) error {
return &invalidMessageFormatErr{messageType: "DataRow"} return &invalidMessageFormatErr{messageType: "DataRow"}
} }
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4 rp += 4
// null // null
if msgSize == -1 { if valueLen == -1 {
dst.Values[i] = nil dst.Values[i] = nil
} else { } else {
if len(src[rp:]) < msgSize { if len(src[rp:]) < valueLen || valueLen < 0 {
return &invalidMessageFormatErr{messageType: "DataRow"} return &invalidMessageFormatErr{messageType: "DataRow"}
} }
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize] dst.Values[i] = src[rp : rp+valueLen : rp+valueLen]
rp += msgSize rp += valueLen
} }
} }

View File

@ -4,22 +4,50 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func FuzzFrontend(f *testing.F) { func FuzzFrontend(f *testing.F) {
testcases := [][]byte{ testcases := []struct {
{'Z', 0, 0, 0, 5}, msgType byte
msgLen uint32
msgBody []byte
}{
{
msgType: 'Z',
msgLen: 2,
msgBody: []byte{'I'},
},
{
msgType: 'Z',
msgLen: 5,
msgBody: []byte{'I'},
},
} }
for _, tc := range testcases { for _, tc := range testcases {
f.Add(tc) f.Add(tc.msgType, tc.msgLen, tc.msgBody)
} }
f.Fuzz(func(t *testing.T, encodedMsg []byte) { f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) {
// Prune any msgLen > len(msgBody) because they would hang the test waiting for more input.
if int(msgLen) > len(msgBody)+4 {
return
}
// Prune any messages that are too long.
if msgLen > 128 || len(msgBody) > 128 {
return
}
r := &bytes.Buffer{} r := &bytes.Buffer{}
w := &bytes.Buffer{} w := &bytes.Buffer{}
fe := pgproto3.NewFrontend(r, w) fe := pgproto3.NewFrontend(r, w)
var encodedMsg []byte
encodedMsg = append(encodedMsg, msgType)
encodedMsg = pgio.AppendUint32(encodedMsg, msgLen)
encodedMsg = append(encodedMsg, msgBody...)
_, err := r.Write(encodedMsg) _, err := r.Write(encodedMsg)
require.NoError(t, err) require.NoError(t, err)

View File

@ -22,6 +22,10 @@ func (*NotificationResponse) Backend() {}
func (dst *NotificationResponse) Decode(src []byte) error { func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src) buf := bytes.NewBuffer(src)
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"}
}
pid := binary.BigEndian.Uint32(buf.Next(4)) pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0) b, err := buf.ReadBytes(0)

View File

@ -46,10 +46,11 @@ func (e *invalidMessageLenErr) Error() string {
type invalidMessageFormatErr struct { type invalidMessageFormatErr struct {
messageType string messageType string
details string
} }
func (e *invalidMessageFormatErr) Error() string { func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType) return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details)
} }
type writeError struct { type writeError struct {

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('A')
uint32(5)
[]byte("0")

View File

@ -1,2 +0,0 @@
go test fuzz v1
[]byte("0\x00\x00\x00\x02")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('D')
uint32(21)
[]byte("00\xb300000000000000")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('C')
uint32(4)
[]byte("0")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('R')
uint32(13)
[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8")