mirror of https://github.com/jackc/pgx.git
Refactor encoding parameters for prepared statements
parent
9e289cb186
commit
a636ef31a4
24
conn.go
24
conn.go
|
@ -976,32 +976,12 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
wbuf.WriteInt16(int16(len(ps.ParameterOids)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
switch arguments[i].(type) {
|
wbuf.WriteInt16(chooseParameterFormatCode(c.ConnInfo, oid, arguments[i]))
|
||||||
case pgtype.BinaryEncoder:
|
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
|
||||||
case pgtype.TextEncoder:
|
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
|
||||||
case string, *string:
|
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
|
||||||
default:
|
|
||||||
if dt, ok := c.ConnInfo.DataTypeForOid(oid); ok {
|
|
||||||
switch dt.Value.(type) {
|
|
||||||
case pgtype.BinaryEncoder:
|
|
||||||
wbuf.WriteInt16(BinaryFormatCode)
|
|
||||||
case pgtype.TextEncoder:
|
|
||||||
wbuf.WriteInt16(TextFormatCode)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("value for oid %v does not implement pgtype.BinaryEncoder or pgtype.TextEncoder", oid)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("unknown type for oid %v", oid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(arguments)))
|
wbuf.WriteInt16(int16(len(arguments)))
|
||||||
for i, oid := range ps.ParameterOids {
|
for i, oid := range ps.ParameterOids {
|
||||||
if err := Encode(wbuf, oid, arguments[i]); err != nil {
|
if err := encodePreparedStatementArgument(wbuf, oid, arguments[i]); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) {
|
||||||
|
|
||||||
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
wbuf.WriteInt16(int16(len(ct.columnNames)))
|
||||||
for i, val := range values {
|
for i, val := range values {
|
||||||
err = Encode(wbuf, ps.FieldDescriptions[i].DataType, val)
|
err = encodePreparedStatementArgument(wbuf, ps.FieldDescriptions[i].DataType, val)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
ct.cancelCopyIn()
|
ct.cancelCopyIn()
|
||||||
return 0, err
|
return 0, err
|
||||||
|
|
31
values.go
31
values.go
|
@ -71,10 +71,7 @@ func (e SerializationError) Error() string {
|
||||||
return string(e)
|
return string(e)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode encodes arg into wbuf as the type oid. This allows implementations
|
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||||
// of the Encoder interface to delegate the actual work of encoding to the
|
|
||||||
// built-in functionality.
|
|
||||||
func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
|
||||||
if arg == nil {
|
if arg == nil {
|
||||||
wbuf.WriteInt32(-1)
|
wbuf.WriteInt32(-1)
|
||||||
return nil
|
return nil
|
||||||
|
@ -112,7 +109,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return Encode(wbuf, oid, v)
|
return encodePreparedStatementArgument(wbuf, oid, v)
|
||||||
case string:
|
case string:
|
||||||
return encodeString(wbuf, oid, arg)
|
return encodeString(wbuf, oid, arg)
|
||||||
case []byte:
|
case []byte:
|
||||||
|
@ -127,7 +124,7 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
arg = refVal.Elem().Interface()
|
arg = refVal.Elem().Interface()
|
||||||
return Encode(wbuf, oid, arg)
|
return encodePreparedStatementArgument(wbuf, oid, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
|
if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
|
||||||
|
@ -152,11 +149,31 @@ func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||||
return Encode(wbuf, oid, strippedArg)
|
return encodePreparedStatementArgument(wbuf, oid, strippedArg)
|
||||||
}
|
}
|
||||||
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// chooseParameterFormatCode determines the correct format code for an
|
||||||
|
// argument to a prepared statement. It defaults to TextFormatCode if no
|
||||||
|
// determination can be made.
|
||||||
|
func chooseParameterFormatCode(ci *pgtype.ConnInfo, oid pgtype.Oid, arg interface{}) int16 {
|
||||||
|
switch arg.(type) {
|
||||||
|
case pgtype.BinaryEncoder:
|
||||||
|
return BinaryFormatCode
|
||||||
|
case string, *string, pgtype.TextEncoder:
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt, ok := ci.DataTypeForOid(oid); ok {
|
||||||
|
if _, ok := dt.Value.(pgtype.BinaryEncoder); ok {
|
||||||
|
return BinaryFormatCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TextFormatCode
|
||||||
|
}
|
||||||
|
|
||||||
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
func stripNamedType(val *reflect.Value) (interface{}, bool) {
|
||||||
switch val.Kind() {
|
switch val.Kind() {
|
||||||
case reflect.Int:
|
case reflect.Int:
|
||||||
|
|
Loading…
Reference in New Issue