mirror of https://github.com/jackc/pgx.git
Detect unsafe pgtype.DriverBytes usage
Add test for unsafe usage and test for correct usage that ensures driver memory is actually used.query-exec-mode
parent
b1e4b96e6c
commit
ffc5a692cb
|
@ -36,7 +36,7 @@ func TestByteaCodec(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDriverBytes(t *testing.T) {
|
func TestDriverBytesQueryRow(t *testing.T) {
|
||||||
conn := testutil.MustConnectPgx(t)
|
conn := testutil.MustConnectPgx(t)
|
||||||
defer testutil.MustCloseContext(t, conn)
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
@ -44,14 +44,47 @@ func TestDriverBytes(t *testing.T) {
|
||||||
|
|
||||||
var buf []byte
|
var buf []byte
|
||||||
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
|
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
|
||||||
|
require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDriverBytes(t *testing.T) {
|
||||||
|
conn := testutil.MustConnectPgx(t)
|
||||||
|
defer testutil.MustCloseContext(t, conn)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
argBuf := make([]byte, 128)
|
||||||
|
for i := range argBuf {
|
||||||
|
argBuf[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
rowCount := 0
|
||||||
|
resultBuf := argBuf
|
||||||
|
detectedResultMutation := false
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
// At some point the buffer should be reused and change.
|
||||||
|
if bytes.Compare(argBuf, resultBuf) != 0 {
|
||||||
|
detectedResultMutation = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
require.Len(t, buf, 2)
|
require.Len(t, resultBuf, len(argBuf))
|
||||||
require.Equal(t, buf, []byte{1, 2})
|
require.Equal(t, resultBuf, argBuf)
|
||||||
require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)")
|
require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
|
||||||
|
}
|
||||||
|
|
||||||
// Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver
|
require.True(t, detectedResultMutation)
|
||||||
// doesn't reuse buffers at the present.
|
|
||||||
|
err = rows.Err()
|
||||||
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestPreallocBytes(t *testing.T) {
|
func TestPreallocBytes(t *testing.T) {
|
||||||
|
|
7
rows.go
7
rows.go
|
@ -79,6 +79,13 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
||||||
return rows.Err()
|
return rows.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, d := range dest {
|
||||||
|
if _, ok := d.(*pgtype.DriverBytes); ok {
|
||||||
|
rows.Close()
|
||||||
|
return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
if rows.Err() == nil {
|
if rows.Err() == nil {
|
||||||
return ErrNoRows
|
return ErrNoRows
|
||||||
|
|
Loading…
Reference in New Issue