diff --git a/conn.go b/conn.go index bdb229a9..6c6998b5 100644 --- a/conn.go +++ b/conn.go @@ -976,32 +976,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} wbuf.WriteInt16(int16(len(ps.ParameterOids))) for i, oid := range ps.ParameterOids { - switch arguments[i].(type) { - case pgtype.BinaryEncoder: - wbuf.WriteInt16(BinaryFormatCode) - case pgtype.TextEncoder: - wbuf.WriteInt16(TextFormatCode) - case string, *string: - wbuf.WriteInt16(TextFormatCode) - default: - if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok { - switch dt.Value.(type) { - case pgtype.BinaryEncoder: - wbuf.WriteInt16(BinaryFormatCode) - case pgtype.TextEncoder: - wbuf.WriteInt16(TextFormatCode) - default: - return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid) - } - } else { - return fmt.Errorf("unknown type for oid %v", oid) - } - } + wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i])) } wbuf.WriteInt16(int16(len(arguments))) for i, oid := range ps.ParameterOids { - if err := Encode(wbuf, oid, arguments[i]); err != nil { + if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil { return err } } diff --git a/copy_from.go b/copy_from.go index 1f8a2306..9fc76a7b 100644 --- a/copy_from.go +++ b/copy_from.go @@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { wbuf.WriteInt16(int16(len(ct.columnNames))) for i, val := range values { - err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val) + err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val) if err != nil { ct.cancelCopyIn() return 0, err diff --git a/values.go b/values.go index aec3cda7..49df5d89 100644 --- a/values.go +++ b/values.go @@ -71,10 +71,7 @@ func (e SerializationError) Error() string { return string(e) } -// Encode encodes arg into wbuf as the type oid. This allows implementations -// of the Encoder interface to delegate the actual work of encoding to the -// built-in functionality. -func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { +func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if arg == nil { wbuf.WriteInt32(-1) return nil @@ -112,7 +109,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { if err != nil { return err } - return Encode(wbuf, oid, v) + return encodePreparedStatementArgument(wbuf, oid, v) case string: return encodeString(wbuf, oid, arg) case []byte: @@ -127,7 +124,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { return nil } arg = refVal.Elem().Interface() - return Encode(wbuf, oid, arg) + return encodePreparedStatementArgument(wbuf, oid, arg) } if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok { @@ -152,11 +149,31 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error { } if strippedArg, ok := stripNamedType(&refVal); ok { - return Encode(wbuf, oid, strippedArg) + return encodePreparedStatementArgument(wbuf, oid, strippedArg) } return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg)) } +// chooseParameterFormatCode determines the correct format code for an +// argument to a prepared statement. It defaults to TextFormatCode if no +// determination can be made. +func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 { + switch arg.(type) { + case pgtype.BinaryEncoder: + return BinaryFormatCode + case string, *string, pgtype.TextEncoder: + return TextFormatCode + } + + if dt, ok := ci.DataTypeForOid(oid); ok { + if _, ok := dt.Value.(pgtype.BinaryEncoder); ok { + return BinaryFormatCode + } + } + + return TextFormatCode +} + func stripNamedType(val *reflect.Value) (interface{}, bool) { switch val.Kind() { case reflect.Int: