mirror of https://github.com/jackc/pgx.git
Detect unsafe pgtype.DriverBytes usage
Add test for unsafe usage and test for correct usage that ensures driver memory is actually used.query-exec-mode
parent
b1e4b96e6c
commit
ffc5a692cb
|
@ -36,7 +36,7 @@ func TestByteaCodec(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestDriverBytes(t *testing.T) {
|
||||
func TestDriverBytesQueryRow(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
|
@ -44,14 +44,47 @@ func TestDriverBytes(t *testing.T) {
|
|||
|
||||
var buf []byte
|
||||
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
|
||||
require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
|
||||
}
|
||||
|
||||
func TestDriverBytes(t *testing.T) {
|
||||
conn := testutil.MustConnectPgx(t)
|
||||
defer testutil.MustCloseContext(t, conn)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
argBuf := make([]byte, 128)
|
||||
for i := range argBuf {
|
||||
argBuf[i] = byte(i)
|
||||
}
|
||||
|
||||
rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
|
||||
require.NoError(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
require.Len(t, buf, 2)
|
||||
require.Equal(t, buf, []byte{1, 2})
|
||||
require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)")
|
||||
rowCount := 0
|
||||
resultBuf := argBuf
|
||||
detectedResultMutation := false
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
|
||||
// Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver
|
||||
// doesn't reuse buffers at the present.
|
||||
// At some point the buffer should be reused and change.
|
||||
if bytes.Compare(argBuf, resultBuf) != 0 {
|
||||
detectedResultMutation = true
|
||||
}
|
||||
|
||||
err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Len(t, resultBuf, len(argBuf))
|
||||
require.Equal(t, resultBuf, argBuf)
|
||||
require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
|
||||
}
|
||||
|
||||
require.True(t, detectedResultMutation)
|
||||
|
||||
err = rows.Err()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestPreallocBytes(t *testing.T) {
|
||||
|
|
7
rows.go
7
rows.go
|
@ -79,6 +79,13 @@ func (r *connRow) Scan(dest ...interface{}) (err error) {
|
|||
return rows.Err()
|
||||
}
|
||||
|
||||
for _, d := range dest {
|
||||
if _, ok := d.(*pgtype.DriverBytes); ok {
|
||||
rows.Close()
|
||||
return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow")
|
||||
}
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
if rows.Err() == nil {
|
||||
return ErrNoRows
|
||||
|
|
Loading…
Reference in New Issue