Better fuzz testing and fix several bugs it found

Fix infinite loop in AuthenticationSASL.Decode
Fix panic in CommandComplete.Decode
Fix panic in DataRow.Decode
Fix panic in NotificationResponse.Decode
This commit is contained in:
Jack Christensen 2022-07-23 16:13:06 -05:00
parent 9d0f27bc4b
commit 7f382f5190
11 changed files with 67 additions and 16 deletions

View File

@ -36,10 +36,11 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
authMechanisms := src[4:] authMechanisms := src[4:]
for len(authMechanisms) > 1 { for len(authMechanisms) > 1 {
idx := bytes.IndexByte(authMechanisms, 0) idx := bytes.IndexByte(authMechanisms, 0)
if idx > 0 { if idx == -1 {
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx])) return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
authMechanisms = authMechanisms[idx+1:]
} }
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
authMechanisms = authMechanisms[idx+1:]
} }
return nil return nil

View File

@ -18,8 +18,11 @@ func (*CommandComplete) Backend() {}
// type identifier and 4 byte message length. // type identifier and 4 byte message length.
func (dst *CommandComplete) Decode(src []byte) error { func (dst *CommandComplete) Decode(src []byte) error {
idx := bytes.IndexByte(src, 0) idx := bytes.IndexByte(src, 0)
if idx == -1 {
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
}
if idx != len(src)-1 { if idx != len(src)-1 {
return &invalidMessageFormatErr{messageType: "CommandComplete"} return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
} }
dst.CommandTag = src[:idx] dst.CommandTag = src[:idx]

View File

@ -43,19 +43,19 @@ func (dst *DataRow) Decode(src []byte) error {
return &invalidMessageFormatErr{messageType: "DataRow"} return &invalidMessageFormatErr{messageType: "DataRow"}
} }
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:]))) valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
rp += 4 rp += 4
// null // null
if msgSize == -1 { if valueLen == -1 {
dst.Values[i] = nil dst.Values[i] = nil
} else { } else {
if len(src[rp:]) < msgSize { if len(src[rp:]) < valueLen || valueLen < 0 {
return &invalidMessageFormatErr{messageType: "DataRow"} return &invalidMessageFormatErr{messageType: "DataRow"}
} }
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize] dst.Values[i] = src[rp : rp+valueLen : rp+valueLen]
rp += msgSize rp += valueLen
} }
} }

View File

@ -4,22 +4,50 @@ import (
"bytes" "bytes"
"testing" "testing"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/pgproto3" "github.com/jackc/pgx/v5/pgproto3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func FuzzFrontend(f *testing.F) { func FuzzFrontend(f *testing.F) {
testcases := [][]byte{ testcases := []struct {
{'Z', 0, 0, 0, 5}, msgType byte
msgLen uint32
msgBody []byte
}{
{
msgType: 'Z',
msgLen: 2,
msgBody: []byte{'I'},
},
{
msgType: 'Z',
msgLen: 5,
msgBody: []byte{'I'},
},
} }
for _, tc := range testcases { for _, tc := range testcases {
f.Add(tc) f.Add(tc.msgType, tc.msgLen, tc.msgBody)
} }
f.Fuzz(func(t *testing.T, encodedMsg []byte) { f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) {
// Prune any msgLen > len(msgBody) because they would hang the test waiting for more input.
if int(msgLen) > len(msgBody)+4 {
return
}
// Prune any messages that are too long.
if msgLen > 128 || len(msgBody) > 128 {
return
}
r := &bytes.Buffer{} r := &bytes.Buffer{}
w := &bytes.Buffer{} w := &bytes.Buffer{}
fe := pgproto3.NewFrontend(r, w) fe := pgproto3.NewFrontend(r, w)
var encodedMsg []byte
encodedMsg = append(encodedMsg, msgType)
encodedMsg = pgio.AppendUint32(encodedMsg, msgLen)
encodedMsg = append(encodedMsg, msgBody...)
_, err := r.Write(encodedMsg) _, err := r.Write(encodedMsg)
require.NoError(t, err) require.NoError(t, err)

View File

@ -22,6 +22,10 @@ func (*NotificationResponse) Backend() {}
func (dst *NotificationResponse) Decode(src []byte) error { func (dst *NotificationResponse) Decode(src []byte) error {
buf := bytes.NewBuffer(src) buf := bytes.NewBuffer(src)
if buf.Len() < 4 {
return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"}
}
pid := binary.BigEndian.Uint32(buf.Next(4)) pid := binary.BigEndian.Uint32(buf.Next(4))
b, err := buf.ReadBytes(0) b, err := buf.ReadBytes(0)

View File

@ -46,10 +46,11 @@ func (e *invalidMessageLenErr) Error() string {
type invalidMessageFormatErr struct { type invalidMessageFormatErr struct {
messageType string messageType string
details string
} }
func (e *invalidMessageFormatErr) Error() string { func (e *invalidMessageFormatErr) Error() string {
return fmt.Sprintf("%s body is invalid", e.messageType) return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details)
} }
type writeError struct { type writeError struct {

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('A')
uint32(5)
[]byte("0")

View File

@ -1,2 +0,0 @@
go test fuzz v1
[]byte("0\x00\x00\x00\x02")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('D')
uint32(21)
[]byte("00\xb300000000000000")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('C')
uint32(4)
[]byte("0")

View File

@ -0,0 +1,4 @@
go test fuzz v1
byte('R')
uint32(13)
[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8")