Refactor encoding parameters for prepared statements

v3-numeric-wip
Jack Christensen 2017-03-18 14:23:04 -05:00
parent 9e289cb186
commit a636ef31a4
3 changed files with 27 additions and 30 deletions

24
conn.go
View File

@ -976,32 +976,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
for i, oid := range ps.ParameterOids {
switch arguments[i].(type) {
case pgtype.BinaryEncoder:
wbuf.WriteInt16(BinaryFormatCode)
case pgtype.TextEncoder:
wbuf.WriteInt16(TextFormatCode)
case string, *string:
wbuf.WriteInt16(TextFormatCode)
default:
if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok {
switch dt.Value.(type) {
case pgtype.BinaryEncoder:
wbuf.WriteInt16(BinaryFormatCode)
case pgtype.TextEncoder:
wbuf.WriteInt16(TextFormatCode)
default:
return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid)
}
} else {
return fmt.Errorf("unknown type for oid %v", oid)
}
}
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
}
wbuf.WriteInt16(int16(len(arguments)))
for i, oid := range ps.ParameterOids {
if err := Encode(wbuf, oid, arguments[i]); err != nil {
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil {
return err
}
}

View File

@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) {
wbuf.WriteInt16(int16(len(ct.columnNames)))
for i, val := range values {
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val)
if err != nil {
ct.cancelCopyIn()
return 0, err

View File

@ -71,10 +71,7 @@ func (e SerializationError) Error() string {
return string(e)
}
// Encode encodes arg into wbuf as the type oid. This allows implementations
// of the Encoder interface to delegate the actual work of encoding to the
// built-in functionality.
func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
if arg == nil {
wbuf.WriteInt32(-1)
return nil
@ -112,7 +109,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
if err != nil {
return err
}
return Encode(wbuf, oid, v)
return encodePreparedStatementArgument(wbuf, oid, v)
case string:
return encodeString(wbuf, oid, arg)
case []byte:
@ -127,7 +124,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
return nil
}
arg = refVal.Elem().Interface()
return Encode(wbuf, oid, arg)
return encodePreparedStatementArgument(wbuf, oid, arg)
}
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
@ -152,11 +149,31 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return Encode(wbuf, oid, strippedArg)
return encodePreparedStatementArgument(wbuf, oid, strippedArg)
}
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// chooseParameterFormatCode determines the correct format code for an
// argument to a prepared statement. It defaults to TextFormatCode if no
// determination can be made.
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 {
switch arg.(type) {
case pgtype.BinaryEncoder:
return BinaryFormatCode
case string, *string, pgtype.TextEncoder:
return TextFormatCode
}
if dt, ok := ci.DataTypeForOid(oid); ok {
if _, ok := dt.Value.(pgtype.BinaryEncoder); ok {
return BinaryFormatCode
}
}
return TextFormatCode
}
func stripNamedType(val *reflect.Value) (interface{}, bool) {
switch val.Kind() {
case reflect.Int: