pgx/values.go

451 lines
12 KiB
Go

package pgx
import (
"bytes"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"github.com/jackc/pgx/pgtype"
)
// PostgreSQL oids for common types. Scalar types are listed first and array
// types second; each group is in ascending oid order.
const (
	// Scalar types.
	BoolOid        = 16
	ByteaOid       = 17
	CharOid        = 18
	NameOid        = 19
	Int8Oid        = 20
	Int2Oid        = 21
	Int4Oid        = 23
	TextOid        = 25
	OidOid         = 26
	TidOid         = 27
	XidOid         = 28
	CidOid         = 29
	JsonOid        = 114
	CidrOid        = 650
	Float4Oid      = 700
	Float8Oid      = 701
	UnknownOid     = 705
	InetOid        = 869
	AclitemOid     = 1033
	VarcharOid     = 1043
	DateOid        = 1082
	TimestampOid   = 1114
	TimestampTzOid = 1184
	RecordOid      = 2249
	UuidOid        = 2950
	JsonbOid       = 3802

	// Array types.
	CidrArrayOid        = 651
	BoolArrayOid        = 1000
	ByteaArrayOid       = 1001
	Int2ArrayOid        = 1005
	Int4ArrayOid        = 1007
	TextArrayOid        = 1009
	VarcharArrayOid     = 1015
	Int8ArrayOid        = 1016
	Float4ArrayOid      = 1021
	Float8ArrayOid      = 1022
	AclitemArrayOid     = 1034
	InetArrayOid        = 1041
	TimestampArrayOid   = 1115
	DateArrayOid        = 1182
	TimestampTzArrayOid = 1185
)
// PostgreSQL format codes. They identify whether a value is represented in
// text or in binary form.
const (
	// TextFormatCode marks a value as using the text representation.
	TextFormatCode = 0

	// BinaryFormatCode marks a value as using the binary representation.
	BinaryFormatCode = 1
)
// Limits of the platform-sized integer types, derived from the width of uint
// so they are correct on both 32-bit and 64-bit builds.
const (
	// maxUint has every bit of a uint set.
	maxUint = ^uint(0)

	// maxInt is the largest int: maxUint with the top (sign) bit cleared.
	maxInt = int(maxUint >> 1)

	// minInt is the smallest int under two's complement.
	minInt = -maxInt - 1
)
// DefaultTypeFormats maps type names to their default requested format (text
// or binary). In theory the Scanner interface should be the one to determine
// the format of the returned values. However, the query has already been
// executed by the time Scan is called, so it has no chance to set the format.
// Types that should always be returned in binary must therefore be listed
// here.
//
// Values are TextFormatCode or BinaryFormatCode. The map is nil until this
// file's init function populates it.
var DefaultTypeFormats map[string]int16
// internalNativeGoTypeFormats maps a type oid to the format code used when
// encoding native Go values (those not handled through the Encoder
// interface) for that oid. It is populated by this file's init function,
// where every listed oid is assigned BinaryFormatCode.
var internalNativeGoTypeFormats map[pgtype.Oid]int16
// init fills the package-level format tables: DefaultTypeFormats (keyed by
// type name) and internalNativeGoTypeFormats (keyed by oid).
func init() {
	// aclitem and its array form stay in text format: Pg's
	// src/backend/utils/adt/acl.c has only in/out (text) not send/recv (bin).
	textTypeNames := []string{"_aclitem", "aclitem"}

	binaryTypeNames := []string{
		// Array types.
		"_bool", "_bytea", "_cidr", "_float4", "_float8", "_inet",
		"_int2", "_int4", "_int8", "_text", "_timestamp", "_timestamptz",
		"_varchar",
		// Scalar types.
		"bool", "bytea", "char", "cid", "cidr", "date",
		"float4", "float8", "inet", "int2", "int4", "int8",
		"oid", "record", "tid", "timestamp", "timestamptz", "xid",
	}

	DefaultTypeFormats = make(map[string]int16, len(textTypeNames)+len(binaryTypeNames))
	for _, name := range textTypeNames {
		DefaultTypeFormats[name] = TextFormatCode
	}
	for _, name := range binaryTypeNames {
		DefaultTypeFormats[name] = BinaryFormatCode
	}

	// Every oid with a native Go encoding is sent in binary format.
	binaryOids := []pgtype.Oid{
		BoolArrayOid, BoolOid,
		ByteaArrayOid, ByteaOid,
		CidrArrayOid, CidrOid,
		DateOid,
		Float4ArrayOid, Float4Oid,
		Float8ArrayOid, Float8Oid,
		InetArrayOid, InetOid,
		Int2ArrayOid, Int2Oid,
		Int4ArrayOid, Int4Oid,
		Int8ArrayOid, Int8Oid,
		JsonbOid, JsonOid,
		OidOid,
		RecordOid,
		TextArrayOid,
		TimestampArrayOid, TimestampOid,
		TimestampTzArrayOid, TimestampTzOid,
		VarcharArrayOid,
	}

	internalNativeGoTypeFormats = make(map[pgtype.Oid]int16, len(binaryOids))
	for _, oid := range binaryOids {
		internalNativeGoTypeFormats[oid] = BinaryFormatCode
	}
}
// SerializationError is returned when a value cannot be encoded or decoded.
// The string itself is the error message.
type SerializationError string

// Error implements the error interface by returning the message text.
func (se SerializationError) Error() string {
	msg := string(se)
	return msg
}
// Encode encodes arg into wbuf as the type oid. This allows implementations
// of the Encoder interface to delegate the actual work of encoding to the
// built-in functionality.
//
// Dispatch order:
//  1. A nil arg is written as NULL (length -1).
//  2. pgtype.BinaryEncoder, then pgtype.TextEncoder, encode themselves.
//  3. driver.Valuer is unwrapped and its value encoded recursively.
//  4. string and []byte are written as raw bytes.
//  5. Pointers are dereferenced (nil pointers become NULL) and re-encoded.
//  6. If the connection knows a data type for oid, arg is assigned to that
//     type's value and encoded in binary.
//  7. Named types with an integer or string kind are converted to their
//     underlying type and re-encoded.
//
// A SerializationError is returned when none of these apply, or when the data
// type registered for oid cannot encode itself in binary.
func Encode(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
	if arg == nil {
		wbuf.WriteInt32(-1)
		return nil
	}

	// writeEncoded emits an already encoded value: a -1 length for NULL,
	// otherwise the byte length followed by the bytes themselves.
	writeEncoded := func(buf *bytes.Buffer, null bool) {
		if null {
			wbuf.WriteInt32(-1)
			return
		}
		wbuf.WriteInt32(int32(buf.Len()))
		wbuf.WriteBytes(buf.Bytes())
	}

	switch arg := arg.(type) {
	case pgtype.BinaryEncoder:
		buf := &bytes.Buffer{}
		null, err := arg.EncodeBinary(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncoded(buf, null)
		return nil
	case pgtype.TextEncoder:
		buf := &bytes.Buffer{}
		null, err := arg.EncodeText(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncoded(buf, null)
		return nil
	case driver.Valuer:
		v, err := arg.Value()
		if err != nil {
			return err
		}
		return Encode(wbuf, oid, v)
	case string:
		return encodeString(wbuf, oid, arg)
	case []byte:
		return encodeByteSlice(wbuf, oid, arg)
	}

	refVal := reflect.ValueOf(arg)

	if refVal.Kind() == reflect.Ptr {
		if refVal.IsNil() {
			wbuf.WriteInt32(-1)
			return nil
		}
		return Encode(wbuf, oid, refVal.Elem().Interface())
	}

	if dt, ok := wbuf.conn.ConnInfo.DataTypeForOid(oid); ok {
		value := dt.Value
		if err := value.Set(arg); err != nil {
			return err
		}

		// Not every registered value is guaranteed to have a binary encoder;
		// report that as an error instead of panicking on the assertion.
		encoder, ok := value.(pgtype.BinaryEncoder)
		if !ok {
			return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T does not implement pgtype.BinaryEncoder", arg, oid, value))
		}

		buf := &bytes.Buffer{}
		null, err := encoder.EncodeBinary(wbuf.conn.ConnInfo, buf)
		if err != nil {
			return err
		}
		writeEncoded(buf, null)
		return nil
	}

	if strippedArg, ok := stripNamedType(&refVal); ok {
		return Encode(wbuf, oid, strippedArg)
	}

	return SerializationError(fmt.Sprintf("Cannot encode %T into oid %v - %T must implement Encoder or be converted to a string", arg, oid, arg))
}
// stripNamedType converts val to the built-in type matching its kind. Only
// the signed and unsigned integer kinds and string are handled.
//
// The bool result is true only when val's type is a named type distinct from
// that built-in type, i.e. when the conversion actually stripped something.
// For a value that already has the built-in type, the value is returned with
// false. For any other kind, the result is (nil, false).
func stripNamedType(val *reflect.Value) (interface{}, bool) {
	var plain interface{}

	switch val.Kind() {
	case reflect.Int:
		plain = int(val.Int())
	case reflect.Int8:
		plain = int8(val.Int())
	case reflect.Int16:
		plain = int16(val.Int())
	case reflect.Int32:
		plain = int32(val.Int())
	case reflect.Int64:
		plain = int64(val.Int())
	case reflect.Uint:
		plain = uint(val.Uint())
	case reflect.Uint8:
		plain = uint8(val.Uint())
	case reflect.Uint16:
		plain = uint16(val.Uint())
	case reflect.Uint32:
		plain = uint32(val.Uint())
	case reflect.Uint64:
		plain = uint64(val.Uint())
	case reflect.String:
		plain = val.String()
	default:
		return nil, false
	}

	// The type was "named" only if it differs from the built-in type.
	return plain, reflect.TypeOf(plain) != val.Type()
}
// Decode decodes from vr into d. d must be a pointer. This allows
// implementations of the Decoder interface to delegate the actual work of
// decoding to the built-in functionality.
//
// Supported destinations are *string, pointers to types whose kind is string,
// and pointers to pointers, which are unwrapped recursively. For a pointer to
// pointer, a null value sets the inner pointer to nil, and a nil inner
// pointer is allocated before decoding a non-null value. Any other
// destination yields an error.
func Decode(vr *ValueReader, d interface{}) error {
	if s, ok := d.(*string); ok {
		*s = decodeText(vr)
		return nil
	}

	ptr := reflect.ValueOf(d)
	if ptr.Kind() != reflect.Ptr {
		return fmt.Errorf("Scan cannot decode into %T", d)
	}

	target := ptr.Elem()
	switch target.Kind() {
	case reflect.Ptr:
		// d is a pointer to pointer: strip one level and try again.
		if vr.Len() == -1 {
			// -1 is a null value; make sure the destination pointer is nil.
			if !target.IsNil() {
				target.Set(reflect.Zero(target.Type()))
			}
			return nil
		}
		if target.IsNil() {
			// Allocate the destination before decoding into it.
			target.Set(reflect.New(target.Type().Elem()))
		}
		return Decode(vr, target.Interface())
	case reflect.String:
		target.SetString(decodeText(vr))
		return nil
	}

	return fmt.Errorf("Scan cannot decode into %T", d)
}
// decodeText reads the current value of vr as a string. Null values (length
// -1) and values in binary format are rejected: a ProtocolError is reported
// through vr.Fatal and the empty string is returned.
func decodeText(vr *ValueReader) string {
	switch {
	case vr.Len() == -1:
		vr.Fatal(ProtocolError("Cannot decode null into string"))
		return ""
	case vr.Type().FormatCode == BinaryFormatCode:
		vr.Fatal(ProtocolError("cannot decode binary value into string"))
		return ""
	default:
		return vr.ReadString(vr.Len())
	}
}
// decodeTextAllowBinary reads the current value of vr as a string without
// checking its format code. A null value (length -1) is reported through
// vr.Fatal as a ProtocolError and yields the empty string.
func decodeTextAllowBinary(vr *ValueReader) string {
	if vr.Len() != -1 {
		return vr.ReadString(vr.Len())
	}

	vr.Fatal(ProtocolError("Cannot decode null into string"))
	return ""
}
// encodeString writes value to w as a 32-bit byte length followed by the
// string's bytes. oid is not consulted. It always returns nil.
func encodeString(w *WriteBuf, oid pgtype.Oid, value string) error {
	payload := []byte(value)
	w.WriteInt32(int32(len(payload)))
	w.WriteBytes(payload)
	return nil
}
// decodeBytea reads the current value of vr as a byte slice. A null value
// (length -1) yields nil. The value must have the bytea oid and be in binary
// format; otherwise a ProtocolError is reported through vr.Fatal and nil is
// returned.
func decodeBytea(vr *ValueReader) []byte {
	switch {
	case vr.Len() == -1:
		return nil
	case vr.Type().DataType != ByteaOid:
		vr.Fatal(ProtocolError(fmt.Sprintf("Cannot decode oid %v into []byte", vr.Type().DataType)))
		return nil
	case vr.Type().FormatCode != BinaryFormatCode:
		vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
		return nil
	default:
		return vr.ReadBytes(vr.Len())
	}
}
// encodeByteSlice writes value to w as a 32-bit byte length followed by the
// bytes themselves. oid is not consulted. It always returns nil.
func encodeByteSlice(w *WriteBuf, oid pgtype.Oid, value []byte) error {
	size := int32(len(value))
	w.WriteInt32(size)
	w.WriteBytes(value)
	return nil
}
// decodeJSON unmarshals the json value held by vr into d using
// encoding/json.
//
// A null value (length -1) leaves d untouched and returns nil. If the value's
// oid is not JsonOid, a ProtocolError is reported through vr.Fatal and
// returned without reading the value. An unmarshalling failure is likewise
// reported through vr.Fatal and returned.
func decodeJSON(vr *ValueReader, d interface{}) error {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != JsonOid {
		// Stop here: reading and unmarshalling a value of the wrong type
		// would only replace this error with a less useful one.
		err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into json", vr.Type().DataType))
		vr.Fatal(err)
		return err
	}

	data := vr.ReadBytes(vr.Len())
	err := json.Unmarshal(data, d)
	if err != nil {
		vr.Fatal(err)
	}
	return err
}
// encodeJSON marshals value with encoding/json and writes it to w as a 32-bit
// byte length followed by the JSON text.
//
// It returns an error, without writing anything, when oid is not JsonOid or
// when value cannot be marshalled. The marshalling error is included in the
// returned message.
func encodeJSON(w *WriteBuf, oid pgtype.Oid, value interface{}) error {
	if oid != JsonOid {
		return fmt.Errorf("cannot encode JSON into oid %v", oid)
	}

	s, err := json.Marshal(value)
	if err != nil {
		// Keep the cause so callers can see why marshalling failed.
		return fmt.Errorf("Failed to encode json from type: %T: %v", value, err)
	}

	w.WriteInt32(int32(len(s)))
	w.WriteBytes(s)
	return nil
}
// decodeJSONB unmarshals the jsonb value held by vr into d using
// encoding/json.
//
// A null value (length -1) leaves d untouched and returns nil. In binary
// format the payload must start with the format byte 1, which is stripped
// before unmarshalling. Every failure (wrong oid, empty binary payload,
// unknown format byte, invalid JSON) is reported through vr.Fatal and
// returned.
func decodeJSONB(vr *ValueReader, d interface{}) error {
	if vr.Len() == -1 {
		return nil
	}

	if vr.Type().DataType != JsonbOid {
		err := ProtocolError(fmt.Sprintf("Cannot decode oid %v into jsonb", vr.Type().DataType))
		vr.Fatal(err)
		return err
	}

	data := vr.ReadBytes(vr.Len())
	if vr.Type().FormatCode == BinaryFormatCode {
		// Guard the index below: a zero-length binary value has no format
		// byte and would otherwise cause an out-of-range panic.
		if len(data) == 0 {
			err := ProtocolError("Cannot decode empty binary jsonb value: missing format byte")
			vr.Fatal(err)
			return err
		}
		if data[0] != 1 {
			err := ProtocolError(fmt.Sprintf("Unknown jsonb format byte: %x", data[0]))
			vr.Fatal(err)
			return err
		}
		data = data[1:]
	}

	err := json.Unmarshal(data, d)
	if err != nil {
		vr.Fatal(err)
	}
	return err
}
// encodeJSONB marshals value with encoding/json and writes it to w in the
// binary jsonb layout: a 32-bit length, the format byte 1, then the JSON
// text. The length counts the format byte.
//
// It returns an error, without writing anything, when oid is not JsonbOid or
// when value cannot be marshalled. The marshalling error is included in the
// returned message.
func encodeJSONB(w *WriteBuf, oid pgtype.Oid, value interface{}) error {
	if oid != JsonbOid {
		return fmt.Errorf("cannot encode JSON into oid %v", oid)
	}

	s, err := json.Marshal(value)
	if err != nil {
		// Keep the cause so callers can see why marshalling failed.
		return fmt.Errorf("Failed to encode json from type: %T: %v", value, err)
	}

	w.WriteInt32(int32(len(s) + 1))
	w.WriteByte(1) // JSONB format header
	w.WriteBytes(s)
	return nil
}