mirror of https://github.com/jackc/pgx.git
CopyFrom parses strings to encode into binary format
https://github.com/jackc/pgx/issues/1277 https://github.com/jackc/pgx/issues/1267pull/1281/head
parent
02d9a5acd8
commit
7c6a31f9d2
|
@ -615,3 +615,32 @@ func TestConnCopyFromCopyFromSourceErrorEnd(t *testing.T) {
|
|||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnCopyFromAutomaticStringConversion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table foo(
|
||||
a int8
|
||||
)`)
|
||||
|
||||
inputRows := [][]interface{}{
|
||||
{"42"},
|
||||
{"7"},
|
||||
{8},
|
||||
}
|
||||
|
||||
copyCount, err := conn.CopyFrom(context.Background(), pgx.Identifier{"foo"}, []string{"a"}, pgx.CopyFromRows(inputRows))
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, len(inputRows), copyCount)
|
||||
|
||||
rows, _ := conn.Query(context.Background(), "select * from foo")
|
||||
nums, err := pgx.CollectRows(rows, pgx.RowTo[int64])
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, []int64{42, 7, 8}, nums)
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
24
values.go
24
values.go
|
@ -1,6 +1,8 @@
|
|||
package pgx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/internal/anynil"
|
||||
"github.com/jackc/pgx/v5/internal/pgio"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
|
@ -36,11 +38,31 @@ func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, er
|
|||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if argBuf2, err2 := tryScanStringCopyValueThenEncode(m, buf, oid, arg); err2 == nil {
|
||||
argBuf = argBuf2
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if argBuf != nil {
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func tryScanStringCopyValueThenEncode(m *pgtype.Map, buf []byte, oid uint32, arg any) ([]byte, error) {
|
||||
s, ok := arg.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("not a string")
|
||||
}
|
||||
|
||||
var v any
|
||||
err := m.Scan(oid, TextFormatCode, []byte(s), &v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.Encode(oid, BinaryFormatCode, v, buf)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue