Simplify copy encoding

query-exec-mode
Jack Christensen 2022-03-05 20:27:36 -06:00
parent e5685a34fc
commit 2831eedef3
3 changed files with 20 additions and 36 deletions

View File

@ -178,7 +178,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b
buf = pgio.AppendInt16(buf, int16(len(ct.columnNames)))
for i, val := range values {
buf, err = encodePreparedStatementArgument(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val)
if err != nil {
return false, nil, err
}

View File

@ -211,6 +211,14 @@ func TestConnCopyFromEnum(t *testing.T) {
_, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`)
require.NoError(t, err)
// Obviously using conn while a tx is in use and registering a type after the connection has been established are
// really bad practices, but for the sake of convenience we do it in the test here.
for _, name := range []string{"fruit", "color"} {
typ, err := conn.LoadType(ctx, name)
require.NoError(t, err)
conn.TypeMap().RegisterType(typ)
}
_, err = tx.Exec(ctx, `create table foo(
a text,
b color,

View File

@ -98,46 +98,22 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error)
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
}
func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
if arg == nil {
func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) {
if anynil.Is(arg) {
return pgio.AppendInt32(buf, -1), nil
}
switch arg := arg.(type) {
case string:
buf = pgio.AppendInt32(buf, int32(len(arg)))
buf = append(buf, arg...)
return buf, nil
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
if err != nil {
return nil, err
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
return pgio.AppendInt32(buf, -1), nil
}
arg = refVal.Elem().Interface()
return encodePreparedStatementArgument(m, buf, oid, arg)
if argBuf != nil {
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
if _, ok := m.TypeForOID(oid); ok {
sp := len(buf)
buf = pgio.AppendInt32(buf, -1)
argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf)
if err != nil {
return nil, err
}
if argBuf != nil {
buf = argBuf
pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4))
}
return buf, nil
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return encodePreparedStatementArgument(m, buf, oid, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
return buf, nil
}
func stripNamedType(val *reflect.Value) (interface{}, bool) {