diff --git a/copy_from.go b/copy_from.go index 7d6a8813..ef982269 100644 --- a/copy_from.go +++ b/copy_from.go @@ -178,7 +178,7 @@ func (ct *copyFrom) buildCopyBuf(buf []byte, sd *pgconn.StatementDescription) (b buf = pgio.AppendInt16(buf, int16(len(ct.columnNames))) for i, val := range values { - buf, err = encodePreparedStatementArgument(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) + buf, err = encodeCopyValue(ct.conn.typeMap, buf, sd.Fields[i].DataTypeOID, val) if err != nil { return false, nil, err } diff --git a/copy_from_test.go b/copy_from_test.go index 5c22dc35..6e2fe952 100644 --- a/copy_from_test.go +++ b/copy_from_test.go @@ -211,6 +211,14 @@ func TestConnCopyFromEnum(t *testing.T) { _, err = tx.Exec(ctx, `create type fruit as enum ('apple', 'orange', 'grape')`) require.NoError(t, err) + // Obviously using conn while a tx is in use and registering a type after the connection has been established are + // really bad practices, but for the sake of convenience we do it in the test here. + for _, name := range []string{"fruit", "color"} { + typ, err := conn.LoadType(ctx, name) + require.NoError(t, err) + conn.TypeMap().RegisterType(typ) + } + _, err = tx.Exec(ctx, `create table foo( a text, b color, diff --git a/values.go b/values.go index fe7f6444..766074bd 100644 --- a/values.go +++ b/values.go @@ -98,46 +98,22 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error) return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg)) } -func encodePreparedStatementArgument(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { - if arg == nil { +func encodeCopyValue(m *pgtype.Map, buf []byte, oid uint32, arg interface{}) ([]byte, error) { + if anynil.Is(arg) { return pgio.AppendInt32(buf, -1), nil } - switch arg := arg.(type) { - case string: - buf = pgio.AppendInt32(buf, int32(len(arg))) - buf = append(buf, arg...) - return buf, nil + sp := len(buf) + buf = pgio.AppendInt32(buf, -1) + argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) + if err != nil { + return nil, err } - - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { - return pgio.AppendInt32(buf, -1), nil - } - arg = refVal.Elem().Interface() - return encodePreparedStatementArgument(m, buf, oid, arg) + if argBuf != nil { + buf = argBuf + pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) } - - if _, ok := m.TypeForOID(oid); ok { - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - argBuf, err := m.Encode(oid, BinaryFormatCode, arg, buf) - if err != nil { - return nil, err - } - if argBuf != nil { - buf = argBuf - pgio.SetInt32(buf[sp:], int32(len(buf[sp:])-4)) - } - return buf, nil - } - - if strippedArg, ok := stripNamedType(&refVal); ok { - return encodePreparedStatementArgument(m, buf, oid, strippedArg) - } - return nil, SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) + return buf, nil } func stripNamedType(val *reflect.Value) (interface{}, bool) {