mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
As recommended by go-staticcheck, but also might be a bit more efficient for the compiler to implement, since we don't care about which slice of bytes is greater than the other one.
138 lines
3.8 KiB
Go
138 lines
3.8 KiB
Go
package pgtype_test
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
|
|
pgx "github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
"github.com/jackc/pgx/v5/pgxtest"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// isExpectedEqBytes returns a predicate that reports whether the value it is
// given is a []byte equal to a.
//
// The predicate distinguishes a nil slice from an empty, non-nil slice: the
// two only match when both are nil or both are non-nil. If either a or the
// checked value is not a []byte the predicate reports false instead of
// panicking, so a mismatch surfaces as an ordinary test failure.
func isExpectedEqBytes(a any) func(any) bool {
	want, wantOK := a.([]byte)

	return func(v any) bool {
		got, gotOK := v.([]byte)
		if !wantOK || !gotOK {
			return false
		}

		// bytes.Equal treats nil and empty slices as equal, so nil-ness has to
		// be compared separately.
		if (want == nil) != (got == nil) {
			return false
		}

		return bytes.Equal(want, got)
	}
}
|
|
|
|
func TestByteaCodec(t *testing.T) {
|
|
pgxtest.RunValueRoundTripTests(context.Background(), t, defaultConnTestRunner, nil, "bytea", []pgxtest.ValueRoundTripTest{
|
|
{[]byte{1, 2, 3}, new([]byte), isExpectedEqBytes([]byte{1, 2, 3})},
|
|
{[]byte{}, new([]byte), isExpectedEqBytes([]byte{})},
|
|
{[]byte(nil), new([]byte), isExpectedEqBytes([]byte(nil))},
|
|
{nil, new([]byte), isExpectedEqBytes([]byte(nil))},
|
|
})
|
|
}
|
|
|
|
func TestDriverBytesQueryRow(t *testing.T) {
|
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
|
var buf []byte
|
|
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.DriverBytes)(&buf))
|
|
require.EqualError(t, err, "cannot scan into *pgtype.DriverBytes from QueryRow")
|
|
})
|
|
}
|
|
|
|
func TestDriverBytes(t *testing.T) {
|
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
|
argBuf := make([]byte, 128)
|
|
for i := range argBuf {
|
|
argBuf[i] = byte(i)
|
|
}
|
|
|
|
rows, err := conn.Query(ctx, `select $1::bytea from generate_series(1, 1000)`, argBuf)
|
|
require.NoError(t, err)
|
|
defer rows.Close()
|
|
|
|
rowCount := 0
|
|
resultBuf := argBuf
|
|
detectedResultMutation := false
|
|
for rows.Next() {
|
|
rowCount++
|
|
|
|
// At some point the buffer should be reused and change.
|
|
if !bytes.Equal(argBuf, resultBuf) {
|
|
detectedResultMutation = true
|
|
}
|
|
|
|
err = rows.Scan((*pgtype.DriverBytes)(&resultBuf))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, resultBuf, len(argBuf))
|
|
require.Equal(t, resultBuf, argBuf)
|
|
require.Equalf(t, cap(resultBuf), len(resultBuf), "cap(resultBuf) is larger than len(resultBuf)")
|
|
}
|
|
|
|
require.True(t, detectedResultMutation)
|
|
|
|
err = rows.Err()
|
|
require.NoError(t, err)
|
|
})
|
|
}
|
|
|
|
func TestPreallocBytes(t *testing.T) {
|
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
|
origBuf := []byte{5, 6, 7, 8}
|
|
buf := origBuf
|
|
err := conn.QueryRow(ctx, `select $1::bytea`, []byte{1, 2}).Scan((*pgtype.PreallocBytes)(&buf))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, buf, 2)
|
|
require.Equal(t, 4, cap(buf))
|
|
require.Equal(t, buf, []byte{1, 2})
|
|
|
|
require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
|
|
|
|
err = conn.QueryRow(ctx, `select $1::bytea`, []byte{3, 4, 5, 6, 7}).Scan((*pgtype.PreallocBytes)(&buf))
|
|
require.NoError(t, err)
|
|
require.Len(t, buf, 5)
|
|
require.Equal(t, 5, cap(buf))
|
|
|
|
require.Equal(t, []byte{1, 2, 7, 8}, origBuf)
|
|
})
|
|
}
|
|
|
|
func TestUndecodedBytes(t *testing.T) {
|
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
|
var buf []byte
|
|
err := conn.QueryRow(ctx, `select 1::int4`).Scan((*pgtype.UndecodedBytes)(&buf))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, buf, 4)
|
|
require.Equal(t, buf, []byte{0, 0, 0, 1})
|
|
})
|
|
}
|
|
|
|
func TestByteaCodecDecodeDatabaseSQLValue(t *testing.T) {
|
|
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
|
|
var buf []byte
|
|
err := conn.QueryRow(ctx, `select '\xa1b2c3d4'::bytea`).Scan(sqlScannerFunc(func(src any) error {
|
|
switch src := src.(type) {
|
|
case []byte:
|
|
buf = make([]byte, len(src))
|
|
copy(buf, src)
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("expected []byte, got %T", src)
|
|
}
|
|
}))
|
|
require.NoError(t, err)
|
|
|
|
require.Len(t, buf, 4)
|
|
require.Equal(t, buf, []byte{0xa1, 0xb2, 0xc3, 0xd4})
|
|
})
|
|
}
|