mirror of https://github.com/jackc/pgx.git
Better fuzz testing and fix several bugs it found
Fix infinite loop in AuthenticationSASL.Decode Fix panic in CommandComplete.Decode Fix panic in DataRow.Decode Fix panic in NotificationResponse.Decodepull/1281/head
parent
9d0f27bc4b
commit
7f382f5190
|
@ -36,10 +36,11 @@ func (dst *AuthenticationSASL) Decode(src []byte) error {
|
|||
authMechanisms := src[4:]
|
||||
for len(authMechanisms) > 1 {
|
||||
idx := bytes.IndexByte(authMechanisms, 0)
|
||||
if idx > 0 {
|
||||
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||
authMechanisms = authMechanisms[idx+1:]
|
||||
if idx == -1 {
|
||||
return &invalidMessageFormatErr{messageType: "AuthenticationSASL", details: "unterminated string"}
|
||||
}
|
||||
dst.AuthMechanisms = append(dst.AuthMechanisms, string(authMechanisms[:idx]))
|
||||
authMechanisms = authMechanisms[idx+1:]
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
@ -18,8 +18,11 @@ func (*CommandComplete) Backend() {}
|
|||
// type identifier and 4 byte message length.
|
||||
func (dst *CommandComplete) Decode(src []byte) error {
|
||||
idx := bytes.IndexByte(src, 0)
|
||||
if idx == -1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "unterminated string"}
|
||||
}
|
||||
if idx != len(src)-1 {
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete"}
|
||||
return &invalidMessageFormatErr{messageType: "CommandComplete", details: "string terminated too early"}
|
||||
}
|
||||
|
||||
dst.CommandTag = src[:idx]
|
||||
|
|
|
@ -43,19 +43,19 @@ func (dst *DataRow) Decode(src []byte) error {
|
|||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
msgSize := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
valueLen := int(int32(binary.BigEndian.Uint32(src[rp:])))
|
||||
rp += 4
|
||||
|
||||
// null
|
||||
if msgSize == -1 {
|
||||
if valueLen == -1 {
|
||||
dst.Values[i] = nil
|
||||
} else {
|
||||
if len(src[rp:]) < msgSize {
|
||||
if len(src[rp:]) < valueLen || valueLen < 0 {
|
||||
return &invalidMessageFormatErr{messageType: "DataRow"}
|
||||
}
|
||||
|
||||
dst.Values[i] = src[rp : rp+msgSize : rp+msgSize]
|
||||
rp += msgSize
|
||||
dst.Values[i] = src[rp : rp+valueLen : rp+valueLen]
|
||||
rp += valueLen
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -4,22 +4,50 @@ import (
|
|||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgproto3"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func FuzzFrontend(f *testing.F) {
|
||||
testcases := [][]byte{
|
||||
{'Z', 0, 0, 0, 5},
|
||||
testcases := []struct {
|
||||
msgType byte
|
||||
msgLen uint32
|
||||
msgBody []byte
|
||||
}{
|
||||
{
|
||||
msgType: 'Z',
|
||||
msgLen: 2,
|
||||
msgBody: []byte{'I'},
|
||||
},
|
||||
{
|
||||
msgType: 'Z',
|
||||
msgLen: 5,
|
||||
msgBody: []byte{'I'},
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
f.Add(tc)
|
||||
f.Add(tc.msgType, tc.msgLen, tc.msgBody)
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, encodedMsg []byte) {
|
||||
f.Fuzz(func(t *testing.T, msgType byte, msgLen uint32, msgBody []byte) {
|
||||
// Prune any msgLen > len(msgBody) because they would hang the test waiting for more input.
|
||||
if int(msgLen) > len(msgBody)+4 {
|
||||
return
|
||||
}
|
||||
|
||||
// Prune any messages that are too long.
|
||||
if msgLen > 128 || len(msgBody) > 128 {
|
||||
return
|
||||
}
|
||||
|
||||
r := &bytes.Buffer{}
|
||||
w := &bytes.Buffer{}
|
||||
fe := pgproto3.NewFrontend(r, w)
|
||||
|
||||
var encodedMsg []byte
|
||||
encodedMsg = append(encodedMsg, msgType)
|
||||
encodedMsg = pgio.AppendUint32(encodedMsg, msgLen)
|
||||
encodedMsg = append(encodedMsg, msgBody...)
|
||||
_, err := r.Write(encodedMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -22,6 +22,10 @@ func (*NotificationResponse) Backend() {}
|
|||
func (dst *NotificationResponse) Decode(src []byte) error {
|
||||
buf := bytes.NewBuffer(src)
|
||||
|
||||
if buf.Len() < 4 {
|
||||
return &invalidMessageFormatErr{messageType: "NotificationResponse", details: "too short"}
|
||||
}
|
||||
|
||||
pid := binary.BigEndian.Uint32(buf.Next(4))
|
||||
|
||||
b, err := buf.ReadBytes(0)
|
||||
|
|
|
@ -46,10 +46,11 @@ func (e *invalidMessageLenErr) Error() string {
|
|||
|
||||
type invalidMessageFormatErr struct {
|
||||
messageType string
|
||||
details string
|
||||
}
|
||||
|
||||
func (e *invalidMessageFormatErr) Error() string {
|
||||
return fmt.Sprintf("%s body is invalid", e.messageType)
|
||||
return fmt.Sprintf("%s body is invalid %s", e.messageType, e.details)
|
||||
}
|
||||
|
||||
type writeError struct {
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
go test fuzz v1
|
||||
byte('A')
|
||||
uint32(5)
|
||||
[]byte("0")
|
|
@ -1,2 +0,0 @@
|
|||
go test fuzz v1
|
||||
[]byte("0\x00\x00\x00\x02")
|
|
@ -0,0 +1,4 @@
|
|||
go test fuzz v1
|
||||
byte('D')
|
||||
uint32(21)
|
||||
[]byte("00\xb300000000000000")
|
|
@ -0,0 +1,4 @@
|
|||
go test fuzz v1
|
||||
byte('C')
|
||||
uint32(4)
|
||||
[]byte("0")
|
|
@ -0,0 +1,4 @@
|
|||
go test fuzz v1
|
||||
byte('R')
|
||||
uint32(13)
|
||||
[]byte("\x00\x00\x00\n0\x12\xebG\x8dI']G\xdac\x95\xb7\x18\xb0\x02\xe8m\xc2\x00\xef\x03\x12\x1b\xbdj\x10\x9f\xf9\xeb\xb8")
|
Loading…
Reference in New Issue