mirror of https://github.com/jackc/pgx.git
Simplify copy encoding
parent
e5685a34fc
commit
2831eedef3
|
@ -178,7 +178,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
|
|||
|
||||
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
buf, err = encodePreparedStatementArgument(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
|
||||
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
|
||||
if err != nil {
|
||||
return false, nil, err
|
||||
}
|
||||
|
|
|
@ -211,6 +211,14 @@ func TestConnCopyFromEnum(t *testing.T) {
|
|||
_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
|
||||
// really bad practices, but for the sake of convenience we do it in the test here.
|
||||
for _, name := range []string{"fruit", "color"} {
|
||||
typ, err := conn.LoadType(ctx, name)
|
||||
require.NoError(t, err)
|
||||
conn.TypeMap().RegisterType(typ)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `create table foo(
|
||||
a text,
|
||||
b color,
|
||||
|
|
46
values.go
46
values.go
|
@ -98,46 +98,22 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error)
|
|||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
|
||||
}
|
||||
|
||||
func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
|
||||
if arg == nil {
|
||||
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
|
||||
if anynil.Is(arg) {
|
||||
return pgio.AppendInt32(buf, -1), nil
|
||||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case string:
|
||||
buf = pgio.AppendInt32(buf, int32(len(arg)))
|
||||
buf = append(buf, arg...)
|
||||
return buf, nil
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
if refVal.IsNil() {
|
||||
return pgio.AppendInt32(buf, -1), nil
|
||||
}
|
||||
arg = refVal.Elem().Interface()
|
||||
return encodePreparedStatementArgument(m, buf, oid, arg)
|
||||
if argBuf != nil {
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
|
||||
if _, ok := m.TypeForOID(oid); ok {
|
||||
sp := len(buf)
|
||||
buf = pgio.AppendInt32(buf, -1)
|
||||
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if argBuf != nil {
|
||||
buf = argBuf
|
||||
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return encodePreparedStatementArgument(m, buf, oid, strippedArg)
|
||||
}
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
||||
|
|
Loading…
Reference in New Issue