diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go index 41c3482e..ae4a8760 100644 --- a/pgtype/bytea_test.go +++ b/pgtype/bytea_test.go @@ -36,7 +36,7 @@ func TestByteaCodec(t *testing.T) { }) } -func TestDriverBytes(t *testing.T) { +func TestDriverBytesQueryRow(t *testing.T) { conn := testutil.MustConnectPgx(t) defer testutil.MustCloseContext(t, conn) @@ -44,14 +44,47 @@ func TestDriverBytes(t *testing.T) { var buf []byte err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf)) + require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow") +} + +func TestDriverBytes(t *testing.T) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + ctx := context.Background() + + argBuf := make([]byte, 128) + for i := range argBuf { + argBuf[i] = byte(i) + } + + rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf) require.NoError(t, err) + defer rows.Close() - require.Len(t, buf, 2) - require.Equal(t, buf, []byte{1, 2}) - require.Equalf(t, cap(buf), len(buf), "cap(buf) is larger than len(buf)") + rowCount := 0 + resultBuf := argBuf + detectedResultMutation := false + for rows.Next() { + rowCount++ - // Don't actually have any way to be sure that the bytes are from the driver at the moment as underlying driver - // doesn't reuse buffers at the present. + // At some point the buffer should be reused and change. + if bytes.Compare(argBuf, resultBuf) != 0 { + detectedResultMutation = true + } + + err = rows.Scan((*pgtype.DriverBytes)(&resultBuf)) + require.NoError(t, err) + + require.Len(t, resultBuf, len(argBuf)) + require.Equal(t, resultBuf, argBuf) + require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)") + } + + require.True(t, detectedResultMutation) + + err = rows.Err() + require.NoError(t, err) } func TestPreallocBytes(t *testing.T) { diff --git a/rows.go b/rows.go index d9b155e6..3ff8c93e 100644 --- a/rows.go +++ b/rows.go @@ -79,6 +79,13 @@ func (r *connRow) Scan(dest ...interface{}) (err error) { return rows.Err() } + for _, d := range dest { + if _, ok := d.(*pgtype.DriverBytes); ok { + rows.Close() + return fmt.Errorf("cannot scan into *pgtype.DriverBytes from QueryRow") + } + } + if !rows.Next() { if rows.Err() == nil { return ErrNoRows