diff --git a/copy_from_test.go b/copy_from_test.go index d979d2dc..49bfcb34 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -615,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnCopyFromAutomaticStringConversion(t *testing.T) { + t.Parallel() + + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table foo( + a int8 + )`) + + inputRows := [][]interface{}{ + {"42"}, + {"7"}, + {8}, + } + + copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows)) + require.NoError(t, err) + require.EqualValues(t, len(inputRows), copyCount) + + rows, _ := conn.Query(context.Background(), "select * from foo") + nums, err := pgx.CollectRows(rows, pgx.RowTo[int64]) + require.NoError(t, err) + + require.Equal(t, []int64{42, 7, 8}, nums) + + ensureConnValid(t, conn) +} diff --git a/values.go b/values.go index d27e071d..19c642fa 100644 --- a/values.go +++ b/values.go @@ -1,6 +1,8 @@ package pgx import ( + "errors" + "github.com/jackc/pgx/v5/internal/anynil" "github.com/jackc/pgx/v5/internal/pgio" "github.com/jackc/pgx/v5/pgtype" @@ -36,11 +38,31 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er buf = pgio.AppendInt32(buf, -1) argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) if err != nil { - return nil, err + if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil { + argBuf = argBuf2 + } else { + return nil, err + } } + if argBuf != nil { buf = argBuf pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } return buf, nil } + +func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) { + s, ok := arg.(string) + if !ok { + return nil, errors.New("not a string") + } + + var v any + err := m.Scan(oid, TextFormatCode, []byte(s), &v) + if err != nil { + return nil, err + } + + return m.Encode(oid, BinaryFormatCode, v, buf) +}