mirror of https://github.com/jackc/pgx.git
Refactor encoding parameters for prepared statements
parent
9e289cb186
commit
a636ef31a4
24
conn.go
24
conn.go
|
@ -976,32 +976,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
|
||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
switch arguments[i].(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case pgtype.TextEncoder:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
case string, *string:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok {
|
||||
switch dt.Value.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
wbuf.WriteInt16(BinaryFormatCode)
|
||||
case pgtype.TextEncoder:
|
||||
wbuf.WriteInt16(TextFormatCode)
|
||||
default:
|
||||
return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown type for oid %v", oid)
|
||||
}
|
||||
}
|
||||
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
|
||||
}
|
||||
|
||||
wbuf.WriteInt16(int16(len(arguments)))
|
||||
for i, oid := range ps.ParameterOids {
|
||||
if err := Encode(wbuf, oid, arguments[i]); err != nil {
|
||||
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
|
@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) {
|
|||
|
||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||
for i, val := range values {
|
||||
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||
if err != nil {
|
||||
ct.cancelCopyIn()
|
||||
return 0, err
|
||||
|
|
31
values.go
31
values.go
|
@ -71,10 +71,7 @@ func (e SerializationError) Error() string {
|
|||
return string(e)
|
||||
}
|
||||
|
||||
// Encode encodes arg into wbuf as the type oid. This allows implementations
|
||||
// of the Encoder interface to delegate the actual work of encoding to the
|
||||
// built-in functionality.
|
||||
func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
if arg == nil {
|
||||
wbuf.WriteInt32(-1)
|
||||
return nil
|
||||
|
@ -112,7 +109,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return Encode(wbuf, oid, v)
|
||||
return encodePreparedStatementArgument(wbuf, oid, v)
|
||||
case string:
|
||||
return encodeString(wbuf, oid, arg)
|
||||
case []byte:
|
||||
|
@ -127,7 +124,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
|||
return nil
|
||||
}
|
||||
arg = refVal.Elem().Interface()
|
||||
return Encode(wbuf, oid, arg)
|
||||
return encodePreparedStatementArgument(wbuf, oid, arg)
|
||||
}
|
||||
|
||||
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
|
||||
|
@ -152,11 +149,31 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
|||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return Encode(wbuf, oid, strippedArg)
|
||||
return encodePreparedStatementArgument(wbuf, oid, strippedArg)
|
||||
}
|
||||
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||
}
|
||||
|
||||
// chooseParameterFormatCode determines the correct format code for an
|
||||
// argument to a prepared statement. It defaults to TextFormatCode if no
|
||||
// determination can be made.
|
||||
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 {
|
||||
switch arg.(type) {
|
||||
case pgtype.BinaryEncoder:
|
||||
return BinaryFormatCode
|
||||
case string, *string, pgtype.TextEncoder:
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOid(oid); ok {
|
||||
if _, ok := dt.Value.(pgtype.BinaryEncoder); ok {
|
||||
return BinaryFormatCode
|
||||
}
|
||||
}
|
||||
|
||||
return TextFormatCode
|
||||
}
|
||||
|
||||
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
||||
switch val.Kind() {
|
||||
case reflect.Int:
|
||||
|
|
Loading…
Reference in New Issue