pgx/value_transcoder.go

730 lines
17 KiB
Go

package pgx
import (
"bytes"
"encoding/hex"
"fmt"
"math"
"regexp"
"strconv"
"time"
"unsafe"
)
// Oids of the PostgreSQL types that have built-in transcoders.
const (
	// Scalar types.
	BoolOid        = 16
	ByteaOid       = 17
	Int8Oid        = 20
	Int2Oid        = 21
	Int4Oid        = 23
	TextOid        = 25
	Float4Oid      = 700
	Float8Oid      = 701
	VarcharOid     = 1043
	DateOid        = 1082
	TimestampTzOid = 1184

	// One-dimensional integer array types.
	Int2ArrayOid = 1005
	Int4ArrayOid = 1007
	Int8ArrayOid = 1016
)
// ValueTranscoder stores all the data necessary to encode and decode values of
// one PostgreSQL type exchanged with the server.
//
// Both decoder functions take the message reader and the size in bytes of the
// encoded value. The built-in decoders in this file report malformed input by
// returning a ProtocolError value as the interface{} result, not as a separate
// error. The built-in encoders write the value's length via WriteInt32, then
// its bytes, and return an error for Go types they do not accept.
type ValueTranscoder struct {
// DecodeText decodes values returned from the server in text format.
DecodeText func(*MessageReader, int32) interface{}
// DecodeBinary decodes values returned from the server in binary format.
// It is left nil for types registered without a binary decoder (e.g. bytea
// and text in init). NOTE(review): presumably callers then request text
// format for such columns — confirm.
DecodeBinary func(*MessageReader, int32) interface{}
// EncodeTo encodes values to send to the server.
EncodeTo func(*WriteBuf, interface{}) error
// EncodeFormat is the format values are encoded for transmission.
// 0 = text
// 1 = binary
// It should agree with what EncodeTo actually writes.
EncodeFormat int16
}
var (
	// ValueTranscoders is used to transcode values being sent to and received
	// from the PostgreSQL server, keyed by type Oid. Additional types can be
	// transcoded by adding a *ValueTranscoder for the appropriate Oid to the map.
	ValueTranscoders map[Oid]*ValueTranscoder

	// defaultTranscoder handles any Oid without an entry in ValueTranscoders.
	defaultTranscoder *ValueTranscoder
)
func init() {
ValueTranscoders = make(map[Oid]*ValueTranscoder)
// bool
ValueTranscoders[BoolOid] = &ValueTranscoder{
DecodeText: decodeBoolFromText,
DecodeBinary: decodeBoolFromBinary,
EncodeTo: encodeBool,
EncodeFormat: 1}
// bytea
ValueTranscoders[ByteaOid] = &ValueTranscoder{
DecodeText: decodeByteaFromText,
EncodeTo: encodeBytea,
EncodeFormat: 1}
// int8
ValueTranscoders[Int8Oid] = &ValueTranscoder{
DecodeText: decodeInt8FromText,
DecodeBinary: decodeInt8FromBinary,
EncodeTo: encodeInt8,
EncodeFormat: 1}
// int2
ValueTranscoders[Int2Oid] = &ValueTranscoder{
DecodeText: decodeInt2FromText,
DecodeBinary: decodeInt2FromBinary,
EncodeTo: encodeInt2,
EncodeFormat: 1}
// int4
ValueTranscoders[Int4Oid] = &ValueTranscoder{
DecodeText: decodeInt4FromText,
DecodeBinary: decodeInt4FromBinary,
EncodeTo: encodeInt4,
EncodeFormat: 1}
// text
ValueTranscoders[TextOid] = &ValueTranscoder{
DecodeText: decodeTextFromText,
EncodeTo: encodeText}
// float4
ValueTranscoders[Float4Oid] = &ValueTranscoder{
DecodeText: decodeFloat4FromText,
DecodeBinary: decodeFloat4FromBinary,
EncodeTo: encodeFloat4,
EncodeFormat: 1}
// float8
ValueTranscoders[Float8Oid] = &ValueTranscoder{
DecodeText: decodeFloat8FromText,
DecodeBinary: decodeFloat8FromBinary,
EncodeTo: encodeFloat8,
EncodeFormat: 1}
// int2[]
ValueTranscoders[Int2ArrayOid] = &ValueTranscoder{
DecodeText: decodeInt2ArrayFromText,
EncodeTo: encodeInt2Array}
// int4[]
ValueTranscoders[Int4ArrayOid] = &ValueTranscoder{
DecodeText: decodeInt4ArrayFromText,
EncodeTo: encodeInt4Array}
// int8[]
ValueTranscoders[Int8ArrayOid] = &ValueTranscoder{
DecodeText: decodeInt8ArrayFromText,
EncodeTo: encodeInt8Array}
// varchar -- same as text
ValueTranscoders[VarcharOid] = ValueTranscoders[Oid(25)]
// date
ValueTranscoders[DateOid] = &ValueTranscoder{
DecodeText: decodeDateFromText,
DecodeBinary: decodeDateFromBinary,
EncodeTo: encodeDate}
// timestamptz
ValueTranscoders[TimestampTzOid] = &ValueTranscoder{
DecodeText: decodeTimestampTzFromText,
DecodeBinary: decodeTimestampTzFromBinary,
EncodeTo: encodeTimestampTz}
// use text transcoder for anything we don't understand
defaultTranscoder = ValueTranscoders[TextOid]
}
// arrayEl matches one element of a PostgreSQL array literal together with the
// '{' or ',' that precedes it. Exactly one capture group participates:
//  1. the contents of a double-quoted element (still backslash-escaped),
//  2. the bare keyword NULL,
//  3. any other unquoted element.
var arrayEl = regexp.MustCompile(`[{,](?:"((?:[^"\\]|\\.)*)"|(NULL)|([^,}]+))`)

// SplitArrayText is used by array transcoders to split array text into elements.
//
// Quoted elements are returned with their backslash escapes removed, and an
// empty quoted element ("") yields an empty string. An unquoted NULL is
// returned as the string "NULL", indistinguishable from a quoted "NULL".
// Nested (multi-dimensional) arrays are not supported. "{}" yields an empty,
// non-nil slice.
func SplitArrayText(text string) (elements []string) {
	// The index form is needed to tell a quoted group that matched the empty
	// string (indices >= 0) from a group that did not participate (indices
	// -1). The string form reports both as "", which used to drop "" elements.
	matches := arrayEl.FindAllStringSubmatchIndex(text, -1)
	elements = make([]string, 0, len(matches))
	for _, m := range matches {
		switch {
		case m[2] >= 0:
			elements = append(elements, unescapeArrayElement(text[m[2]:m[3]]))
		case m[4] >= 0:
			elements = append(elements, text[m[4]:m[5]])
		case m[6] >= 0:
			elements = append(elements, text[m[6]:m[7]])
		}
	}
	return elements
}

// unescapeArrayElement removes backslash escaping from the contents of a
// quoted array element: every `\c` pair becomes c.
func unescapeArrayElement(s string) string {
	out := make([]byte, 0, len(s))
	for i := 0; i < len(s); i++ {
		// arrayEl only captures a backslash as part of a `\c` pair, so the
		// bounds check is purely defensive.
		if s[i] == '\\' && i+1 < len(s) {
			i++
		}
		out = append(out, s[i])
	}
	return string(out)
}
// decodeBoolFromText maps PostgreSQL's text booleans "t" and "f" to true and
// false; any other text yields a ProtocolError.
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	if text == "t" {
		return true
	}
	if text == "f" {
		return false
	}
	return ProtocolError(fmt.Sprintf("Received invalid bool: %v", text))
}
// decodeBoolFromBinary decodes a one-byte binary bool; any non-zero byte is
// true. A size other than 1 yields a ProtocolError without reading from mr.
func decodeBoolFromBinary(mr *MessageReader, size int32) interface{} {
	if size == 1 {
		return mr.ReadByte() != 0
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an bool: %d", size))
}
// encodeBool writes a bool as a binary parameter: length 1, then a single
// byte that is 1 for true and 0 for false.
func encodeBool(w *WriteBuf, value interface{}) error {
	flag, isBool := value.(bool)
	if !isBool {
		return fmt.Errorf("Expected bool, received %T", value)
	}
	var encoded byte // stays 0 for false
	if flag {
		encoded = 1
	}
	w.WriteInt32(1)
	w.WriteByte(encoded)
	return nil
}
// decodeInt8FromText parses a base-10 int8 into an int64, returning a
// ProtocolError when the text is not a valid 64-bit integer.
func decodeInt8FromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	if parsed, err := strconv.ParseInt(text, 10, 64); err == nil {
		return parsed
	}
	return ProtocolError(fmt.Sprintf("Received invalid int8: %v", text))
}
// decodeInt8FromBinary reads an 8-byte binary int8 via mr.ReadInt64. Any
// other size yields a ProtocolError without reading from mr.
func decodeInt8FromBinary(mr *MessageReader, size int32) interface{} {
	if size == 8 {
		return mr.ReadInt64()
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an int8: %d", size))
}
// encodeInt8 writes value as a binary int8 parameter: length 8, then the
// value via WriteBuf.WriteInt64.
//
// Every Go integer type is accepted as long as the value fits in an int64.
// uint64 and uint values above math.MaxInt64 are rejected with an error.
func encodeInt8(w *WriteBuf, value interface{}) error {
	var v int64
	switch value := value.(type) {
	case int8:
		v = int64(value)
	case uint8:
		v = int64(value)
	case int16:
		v = int64(value)
	case uint16:
		v = int64(value)
	case int32:
		v = int64(value)
	case uint32:
		v = int64(value)
	case int64:
		v = value
	case uint64:
		if value > math.MaxInt64 {
			// int64(...) keeps the constant from defaulting to int, which
			// overflows (a compile error) on 32-bit platforms.
			return fmt.Errorf("uint64 %d is larger than max int64 %d", value, int64(math.MaxInt64))
		}
		v = int64(value)
	case int:
		v = int64(value)
	case uint:
		// Widen first: on 32-bit platforms uint cannot hold math.MaxInt64.
		if uint64(value) > math.MaxInt64 {
			return fmt.Errorf("uint %d is larger than max int64 %d", value, int64(math.MaxInt64))
		}
		v = int64(value)
	default:
		return fmt.Errorf("Expected integer representable in int64, received %T %v", value, value)
	}
	w.WriteInt32(8)
	w.WriteInt64(v)
	return nil
}
// decodeInt2FromText parses a base-10 int2 into an int16, returning a
// ProtocolError when the text is not a valid 16-bit integer.
func decodeInt2FromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	parsed, parseErr := strconv.ParseInt(text, 10, 16)
	if parseErr == nil {
		return int16(parsed)
	}
	return ProtocolError(fmt.Sprintf("Received invalid int2: %v", text))
}
// decodeInt2FromBinary reads a 2-byte binary int2 via mr.ReadInt16. Any
// other size yields a ProtocolError without reading from mr.
func decodeInt2FromBinary(mr *MessageReader, size int32) interface{} {
	if size == 2 {
		return mr.ReadInt16()
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an int2: %d", size))
}
// encodeInt2 writes value as a binary int2 parameter: length 2, then the
// value via WriteBuf.WriteInt16.
//
// Every Go integer type is accepted as long as the value lies within
// [math.MinInt16, math.MaxInt16]. Out-of-range values in either direction
// are rejected with an error instead of being silently truncated.
func encodeInt2(w *WriteBuf, value interface{}) error {
	// Widen every accepted input to int64 so one range check covers them all.
	var n int64
	switch value := value.(type) {
	case int8:
		n = int64(value)
	case uint8:
		n = int64(value)
	case int16:
		n = int64(value)
	case uint16:
		n = int64(value)
	case int32:
		n = int64(value)
	case uint32:
		n = int64(value)
	case int64:
		n = value
	case uint64:
		// Check before widening: a huge uint64 would wrap to a negative int64.
		if value > math.MaxInt16 {
			return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
		}
		n = int64(value)
	case int:
		n = int64(value)
	case uint:
		if value > math.MaxInt16 {
			return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
		}
		n = int64(value)
	default:
		return fmt.Errorf("Expected integer representable in int16, received %T %v", value, value)
	}
	// value here is the original interface{}, so %T reports the caller's type.
	if n > math.MaxInt16 {
		return fmt.Errorf("%T %d is larger than max int16 %d", value, value, math.MaxInt16)
	}
	if n < math.MinInt16 {
		return fmt.Errorf("%T %d is smaller than min int16 %d", value, value, math.MinInt16)
	}
	w.WriteInt32(2)
	w.WriteInt16(int16(n))
	return nil
}
// decodeInt4FromText parses a base-10 int4 into an int32, returning a
// ProtocolError when the text is not a valid 32-bit integer.
func decodeInt4FromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	if parsed, err := strconv.ParseInt(text, 10, 32); err == nil {
		return int32(parsed)
	}
	return ProtocolError(fmt.Sprintf("Received invalid int4: %v", text))
}
// decodeInt4FromBinary reads a 4-byte binary int4 via mr.ReadInt32. Any
// other size yields a ProtocolError without reading from mr.
func decodeInt4FromBinary(mr *MessageReader, size int32) interface{} {
	if size == 4 {
		return mr.ReadInt32()
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an int4: %d", size))
}
// encodeInt4 writes value as a binary int4 parameter: length 4, then the
// value via WriteBuf.WriteInt32.
//
// Every Go integer type is accepted as long as the value lies within
// [math.MinInt32, math.MaxInt32]. Out-of-range values in either direction
// are rejected with an error instead of being silently truncated.
func encodeInt4(w *WriteBuf, value interface{}) error {
	// Widen every accepted input to int64 so one range check covers them all.
	var n int64
	switch value := value.(type) {
	case int8:
		n = int64(value)
	case uint8:
		n = int64(value)
	case int16:
		n = int64(value)
	case uint16:
		n = int64(value)
	case int32:
		n = int64(value)
	case uint32:
		n = int64(value)
	case int64:
		n = value
	case uint64:
		// Check before widening: a huge uint64 would wrap to a negative int64.
		if value > math.MaxInt32 {
			return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32)
		}
		n = int64(value)
	case int:
		n = int64(value)
	case uint:
		if value > math.MaxInt32 {
			return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32)
		}
		n = int64(value)
	default:
		return fmt.Errorf("Expected integer representable in int32, received %T %v", value, value)
	}
	// value here is the original interface{}, so %T reports the caller's type.
	if n > math.MaxInt32 {
		return fmt.Errorf("%T %d is larger than max int32 %d", value, value, math.MaxInt32)
	}
	if n < math.MinInt32 {
		return fmt.Errorf("%T %d is smaller than min int32 %d", value, value, math.MinInt32)
	}
	w.WriteInt32(4)
	w.WriteInt32(int32(n))
	return nil
}
// decodeFloat4FromText parses a float4 in text form into a float32,
// returning a ProtocolError when the text is not a valid float.
func decodeFloat4FromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	if parsed, err := strconv.ParseFloat(text, 32); err == nil {
		return float32(parsed)
	}
	return ProtocolError(fmt.Sprintf("Received invalid float4: %v", text))
}
// decodeFloat4FromBinary reads a 4-byte binary float4 and reinterprets its
// bits as a float32. Any other size yields a ProtocolError without reading
// from mr.
func decodeFloat4FromBinary(mr *MessageReader, size int32) interface{} {
	if size == 4 {
		bits := mr.ReadInt32()
		// Same bits, new type: no numeric conversion takes place.
		return *(*float32)(unsafe.Pointer(&bits))
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an float4: %d", size))
}
// encodeFloat4 writes value as a binary float4 parameter: length 4, then
// the float32's IEEE 754 bits via WriteBuf.WriteInt32.
//
// float32 values are sent unchanged. float64 values are narrowed to float32.
// A finite float64 outside [-math.MaxFloat32, math.MaxFloat32] is rejected
// with an error. NaN and ±Inf are representable in float32 and pass through
// (previously +Inf was rejected while -Inf was accepted).
func encodeFloat4(w *WriteBuf, value interface{}) error {
	var v float32
	switch value := value.(type) {
	case float32:
		v = value
	case float64:
		if !math.IsInf(value, 0) {
			if value > math.MaxFloat32 {
				return fmt.Errorf("%T %f is larger than max float32 %f", value, value, math.MaxFloat32)
			}
			if value < -math.MaxFloat32 {
				return fmt.Errorf("%T %f is smaller than min float32 %f", value, value, -math.MaxFloat32)
			}
		}
		v = float32(value)
	default:
		return fmt.Errorf("Expected float representable in float32, received %T %v", value, value)
	}
	w.WriteInt32(4)
	// Reinterpret the float32's bits as an int32 for the wire.
	p := unsafe.Pointer(&v)
	w.WriteInt32(*(*int32)(p))
	return nil
}
// decodeFloat8FromText parses a float8 in text form into a float64,
// returning a ProtocolError when the text is not a valid float.
func decodeFloat8FromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	parsed, parseErr := strconv.ParseFloat(text, 64)
	if parseErr == nil {
		return parsed
	}
	return ProtocolError(fmt.Sprintf("Received invalid float8: %v", text))
}
// decodeFloat8FromBinary reads an 8-byte binary float8 and reinterprets its
// bits as a float64. Any other size yields a ProtocolError without reading
// from mr.
func decodeFloat8FromBinary(mr *MessageReader, size int32) interface{} {
	if size == 8 {
		bits := mr.ReadInt64()
		// Same bits, new type: no numeric conversion takes place.
		return *(*float64)(unsafe.Pointer(&bits))
	}
	return ProtocolError(fmt.Sprintf("Received an invalid size for an float8: %d", size))
}
// encodeFloat8 writes a float32 or float64 as a binary float8 parameter:
// length 8, then the float64's bits reinterpreted as an int64.
func encodeFloat8(w *WriteBuf, value interface{}) error {
	var f float64
	switch typed := value.(type) {
	case float64:
		f = typed
	case float32:
		f = float64(typed)
	default:
		return fmt.Errorf("Expected float representable in float64, received %T %v", value, value)
	}
	bits := *(*int64)(unsafe.Pointer(&f))
	w.WriteInt32(8)
	w.WriteInt64(bits)
	return nil
}
// decodeTextFromText returns the value as read by mr.ReadString(size); text
// values need no further conversion.
func decodeTextFromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	return text
}
// encodeText writes a string parameter as its byte length followed by its
// raw bytes. Non-string values are rejected with an error.
func encodeText(w *WriteBuf, value interface{}) error {
	if str, isString := value.(string); isString {
		w.WriteInt32(int32(len(str)))
		w.WriteBytes([]byte(str))
		return nil
	}
	return fmt.Errorf("Expected string, received %T", value)
}
// decodeByteaFromText decodes a bytea sent in the hex text format ("\x"
// followed by hex digits) into a []byte. Anything else yields a
// ProtocolError.
func decodeByteaFromText(mr *MessageReader, size int32) interface{} {
	s := mr.ReadString(size)
	// Verify the prefix before slicing. s[2:] panics on inputs shorter than
	// two bytes, and skipping two arbitrary characters would misdecode
	// non-hex (escape-format) output.
	if len(s) < 2 || s[0] != '\\' || s[1] != 'x' {
		return ProtocolError(fmt.Sprintf("Can't decode byte array: missing \\x hex prefix - %v", s))
	}
	b, err := hex.DecodeString(s[2:])
	if err != nil {
		return ProtocolError(fmt.Sprintf("Can't decode byte array: %v - %v", err, s))
	}
	return b
}
// encodeBytea writes a []byte parameter as its length followed by the raw
// bytes. Any other type is rejected with an error.
func encodeBytea(w *WriteBuf, value interface{}) error {
	switch data := value.(type) {
	case []byte:
		w.WriteInt32(int32(len(data)))
		w.WriteBytes(data)
		return nil
	default:
		return fmt.Errorf("Expected []byte, received %T", value)
	}
}
// decodeDateFromText parses a "YYYY-MM-DD" date as midnight in time.Local,
// returning a ProtocolError for any other text.
func decodeDateFromText(mr *MessageReader, size int32) interface{} {
	text := mr.ReadString(size)
	if date, err := time.ParseInLocation("2006-01-02", text, time.Local); err == nil {
		return date
	}
	return ProtocolError(fmt.Sprintf("Can't decode date: %v", text))
}
// decodeDateFromBinary decodes a 4-byte binary date, a count of days relative
// to 2000-01-01, into midnight of that day in time.Local. A size other than
// 4 yields a ProtocolError without reading from mr, matching the other
// binary decoders.
func decodeDateFromBinary(mr *MessageReader, size int32) interface{} {
	if size != 4 {
		return ProtocolError(fmt.Sprintf("Received an invalid size for a date: %d", size))
	}
	dayOffset := mr.ReadInt32()
	// Widen before adding: 1+dayOffset in int32 overflows for math.MaxInt32.
	// time.Date normalizes an out-of-range day into later months and years.
	return time.Date(2000, 1, 1+int(dayOffset), 0, 0, 0, 0, time.Local)
}
// encodeDate sends a time.Time as a text date ("YYYY-MM-DD"); the time of
// day and zone are dropped by the format.
func encodeDate(w *WriteBuf, value interface{}) error {
	switch when := value.(type) {
	case time.Time:
		return encodeText(w, when.Format("2006-01-02"))
	default:
		return fmt.Errorf("Expected time.Time, received %T", value)
	}
}
// decodeTimestampTzFromText parses a timestamptz in text form such as
// "2014-04-05 10:20:30.123456-05". The UTC offset may be written as hours
// only ("-05"), hours and minutes ("+05:30") or hours, minutes and seconds.
// Previously only the hours-only form was accepted. Unparseable text yields a
// ProtocolError carrying the error from the hours-only layout.
func decodeTimestampTzFromText(mr *MessageReader, size int32) interface{} {
	s := mr.ReadString(size)
	layouts := [...]string{
		"2006-01-02 15:04:05.999999-07",
		"2006-01-02 15:04:05.999999-07:00",
		"2006-01-02 15:04:05.999999-07:00:00",
	}
	var firstErr error
	for _, layout := range layouts {
		t, err := time.Parse(layout, s)
		if err == nil {
			return t
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return ProtocolError(fmt.Sprintf("Can't decode timestamptz: %v - %v", firstErr, s))
}
// decodeTimestampTzFromBinary decodes an 8-byte binary timestamptz, a count
// of microseconds relative to 2000-01-01 00:00:00 UTC, into a time.Time in
// the local zone. A size other than 8 yields a ProtocolError without reading
// from mr.
func decodeTimestampTzFromBinary(mr *MessageReader, size int32) interface{} {
	if size != 8 {
		return ProtocolError(fmt.Sprintf("Received an invalid size for a timestamptz: %d", size))
	}
	// 946684800 is 2000-01-01 00:00:00 UTC in seconds since the Unix epoch.
	const microsecFromUnixEpochToY2K = int64(946684800 * 1000000)
	microsecSinceUnixEpoch := microsecFromUnixEpochToY2K + mr.ReadInt64()
	// time.Unix normalizes a negative nanosecond part, so instants before
	// 1970 (negative remainder) are still correct.
	return time.Unix(microsecSinceUnixEpoch/1000000, (microsecSinceUnixEpoch%1000000)*1000)
}
// encodeTimestampTz sends a time.Time as text with microsecond precision and
// a numeric UTC offset, e.g. "2014-04-05 10:20:30.123456 -0500".
func encodeTimestampTz(w *WriteBuf, value interface{}) error {
	switch when := value.(type) {
	case time.Time:
		return encodeText(w, when.Format("2006-01-02 15:04:05.999999 -0700"))
	default:
		return fmt.Errorf("Expected time.Time, received %T", value)
	}
}
// decodeInt2ArrayFromText decodes a text int2[] into a []int16. Any element
// that is not a valid 16-bit integer, including a NULL element (which
// SplitArrayText yields as "NULL"), produces a ProtocolError for the whole
// array.
func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} {
	raw := mr.ReadString(size)
	parts := SplitArrayText(raw)
	result := make([]int16, len(parts))
	for idx, part := range parts {
		parsed, err := strconv.ParseInt(part, 10, 16)
		if err != nil {
			return ProtocolError(fmt.Sprintf("Received invalid int2[]: %v", raw))
		}
		result[idx] = int16(parsed)
	}
	return result
}
// int16SliceToArrayString formats nums as a PostgreSQL array literal, e.g.
// []int16{1, -2} becomes "{1,-2}". bytes.Buffer writes always return a nil
// error, so the returned error is always nil as well.
func int16SliceToArrayString(nums []int16) (string, error) {
	var buf bytes.Buffer
	buf.WriteByte('{')
	for idx := range nums {
		if idx != 0 {
			buf.WriteByte(',')
		}
		buf.WriteString(strconv.FormatInt(int64(nums[idx]), 10))
	}
	buf.WriteByte('}')
	return buf.String(), nil
}
// encodeInt2Array sends a []int16 as a text array literal such as "{1,2,3}".
// Any other type is rejected with an error.
func encodeInt2Array(w *WriteBuf, value interface{}) error {
	switch nums := value.(type) {
	case []int16:
		literal, err := int16SliceToArrayString(nums)
		if err != nil {
			return fmt.Errorf("Failed to encode []int16: %v", err)
		}
		return encodeText(w, literal)
	default:
		return fmt.Errorf("Expected []int16, received %T", value)
	}
}
// decodeInt4ArrayFromText decodes a text int4[] into a []int32. Any element
// that is not a valid 32-bit integer, including a NULL element (which
// SplitArrayText yields as "NULL"), produces a ProtocolError for the whole
// array.
func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} {
	s := mr.ReadString(size)
	elements := SplitArrayText(s)
	numbers := make([]int32, 0, len(elements))
	for _, e := range elements {
		// bitSize must be 32 for int4. The previous 16 rejected every value
		// outside the int2 range (e.g. 40000).
		n, err := strconv.ParseInt(e, 10, 32)
		if err != nil {
			return ProtocolError(fmt.Sprintf("Received invalid int4[]: %v", s))
		}
		numbers = append(numbers, int32(n))
	}
	return numbers
}
// int32SliceToArrayString formats nums as a PostgreSQL array literal, e.g.
// []int32{1, -2} becomes "{1,-2}". bytes.Buffer writes always return a nil
// error, so the returned error is always nil as well.
func int32SliceToArrayString(nums []int32) (string, error) {
	var buf bytes.Buffer
	buf.WriteByte('{')
	for idx := range nums {
		if idx != 0 {
			buf.WriteByte(',')
		}
		buf.WriteString(strconv.FormatInt(int64(nums[idx]), 10))
	}
	buf.WriteByte('}')
	return buf.String(), nil
}
// encodeInt4Array sends a []int32 as a text array literal such as "{1,2,3}".
// Any other type is rejected with an error.
func encodeInt4Array(w *WriteBuf, value interface{}) error {
	switch nums := value.(type) {
	case []int32:
		literal, err := int32SliceToArrayString(nums)
		if err != nil {
			return fmt.Errorf("Failed to encode []int32: %v", err)
		}
		return encodeText(w, literal)
	default:
		return fmt.Errorf("Expected []int32, received %T", value)
	}
}
// decodeInt8ArrayFromText decodes a text int8[] into a []int64. Any element
// that is not a valid 64-bit integer, including a NULL element (which
// SplitArrayText yields as "NULL"), produces a ProtocolError for the whole
// array.
func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} {
	s := mr.ReadString(size)
	elements := SplitArrayText(s)
	numbers := make([]int64, 0, len(elements))
	for _, e := range elements {
		// bitSize must be 64 for int8. The previous 16 rejected every value
		// outside the int2 range.
		n, err := strconv.ParseInt(e, 10, 64)
		if err != nil {
			return ProtocolError(fmt.Sprintf("Received invalid int8[]: %v", s))
		}
		numbers = append(numbers, n)
	}
	return numbers
}
// int64SliceToArrayString formats nums as a PostgreSQL array literal, e.g.
// []int64{1, -2} becomes "{1,-2}". bytes.Buffer writes always return a nil
// error, so the returned error is always nil as well.
func int64SliceToArrayString(nums []int64) (string, error) {
	var buf bytes.Buffer
	buf.WriteByte('{')
	for idx := range nums {
		if idx != 0 {
			buf.WriteByte(',')
		}
		buf.WriteString(strconv.FormatInt(nums[idx], 10))
	}
	buf.WriteByte('}')
	return buf.String(), nil
}
// encodeInt8Array sends a []int64 as a text array literal such as "{1,2,3}".
// Any other type is rejected with an error.
func encodeInt8Array(w *WriteBuf, value interface{}) error {
	switch nums := value.(type) {
	case []int64:
		literal, err := int64SliceToArrayString(nums)
		if err != nil {
			return fmt.Errorf("Failed to encode []int64: %v", err)
		}
		return encodeText(w, literal)
	default:
		return fmt.Errorf("Expected []int64, received %T", value)
	}
}