Detect unsafe pgtype.DriverBytes usage

Add test for unsafe usage and test for correct usage that ensures driver
memory is actually used.
query-exec-mode
Jack Christensen 2022-02-26 20:23:35 -06:00
parent b1e4b96e6c
commit ffc5a692cb
2 changed files with 46 additions and 6 deletions

View File

@ -36,7 +36,7 @@ func TestByteaCodec(t *testing.T) {
})
}
func TestDriverBytes(t *testing.T) {
func TestDriverBytesQueryRow(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
@ -44,14 +44,47 @@ func TestDriverBytes(t *testing.T) {
var buf []byte
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
}
func TestDriverBytes(t *testing.T) {
conn := testutil.MustConnectPgx(t)
defer testutil.MustCloseContext(t, conn)
ctx := context.Background()
argBuf := make([]byte, 128)
for i := range argBuf {
argBuf[i] = byte(i)
}
rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
require.NoError(t, err)
defer rows.Close()
require.Len(t, buf, 2)
require.Equal(t, buf, []byte{1, 2})
require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)")
rowCount := 0
resultBuf := argBuf
detectedResultMutation := false
for rows.Next() {
rowCount++
// Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver
// doesn't reuse buffers at the present.
// At some point the buffer should be reused and change.
if bytes.Compare(argBuf, resultBuf) != 0 {
detectedResultMutation = true
}
err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
require.NoError(t, err)
require.Len(t, resultBuf, len(argBuf))
require.Equal(t, resultBuf, argBuf)
require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
}
require.True(t, detectedResultMutation)
err = rows.Err()
require.NoError(t, err)
}
func TestPreallocBytes(t *testing.T) {

View File

@ -79,6 +79,13 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
return rows.Err()
}
for _, d := range dest {
if _, ok := d.(*pgtype.DriverBytes); ok {
rows.Close()
return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
}
}
if !rows.Next() {
if rows.Err() == nil {
return ErrNoRows