Add bytea

v3-numeric-wip
Jack Christensen 2017-03-09 21:07:40 -06:00
parent fa36ad9196
commit bac4af13bb
4 changed files with 255 additions and 0 deletions

160
pgtype/bytea.go Normal file
View File

@ -0,0 +1,160 @@
package pgtype
import (
"encoding/hex"
"fmt"
"io"
"reflect"
"github.com/jackc/pgx/pgio"
)
type Bytea struct {
Bytes []byte
Status Status
}
func (dst *Bytea) ConvertFrom(src interface{}) error {
switch value := src.(type) {
case Bytea:
*dst = value
case []byte:
if value != nil {
*dst = Bytea{Bytes: value, Status: Present}
} else {
*dst = Bytea{Status: Null}
}
default:
if originalSrc, ok := underlyingBytesType(src); ok {
return dst.ConvertFrom(originalSrc)
}
return fmt.Errorf("cannot convert %v to Bytea", value)
}
return nil
}
func (src *Bytea) AssignTo(dst interface{}) error {
switch v := dst.(type) {
case *[]byte:
if src.Status == Present {
*v = src.Bytes
} else {
*v = nil
}
default:
if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr {
el := v.Elem()
switch el.Kind() {
// if dst is a pointer to pointer, strip the pointer and try again
case reflect.Ptr:
if src.Status == Null {
el.Set(reflect.Zero(el.Type()))
return nil
}
if el.IsNil() {
// allocate destination
el.Set(reflect.New(el.Type().Elem()))
}
return src.AssignTo(el.Interface())
default:
if originalDst, ok := underlyingPtrSliceType(dst); ok {
return src.AssignTo(originalDst)
}
}
}
return fmt.Errorf("cannot decode %v into %T", src, dst)
}
return nil
}
// DecodeText only supports the hex format. This has been the default since
// PostgreSQL 9.0.
func (dst *Bytea) DecodeText(r io.Reader) error {
size, err := pgio.ReadInt32(r)
if err != nil {
return err
}
if size == -1 {
*dst = Bytea{Status: Null}
return nil
}
sbuf := make([]byte, int(size))
_, err = io.ReadFull(r, sbuf)
if err != nil {
return err
}
if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' {
return fmt.Errorf("invalid hex format")
}
buf := make([]byte, (len(sbuf)-2)/2)
_, err = hex.Decode(buf, sbuf[2:])
if err != nil {
return err
}
*dst = Bytea{Bytes: buf, Status: Present}
return nil
}
func (dst *Bytea) DecodeBinary(r io.Reader) error {
size, err := pgio.ReadInt32(r)
if err != nil {
return err
}
if size == -1 {
*dst = Bytea{Status: Null}
return nil
}
buf := make([]byte, int(size))
_, err = io.ReadFull(r, buf)
if err != nil {
return err
}
*dst = Bytea{Bytes: buf, Status: Present}
return nil
}
func (src Bytea) EncodeText(w io.Writer) error {
if done, err := encodeNotPresent(w, src.Status); done {
return err
}
str := hex.EncodeToString(src.Bytes)
_, err := pgio.WriteInt32(w, int32(len(str)+2))
if err != nil {
return nil
}
_, err = io.WriteString(w, `\x`)
if err != nil {
return nil
}
_, err = io.WriteString(w, str)
return err
}
func (src Bytea) EncodeBinary(w io.Writer) error {
if done, err := encodeNotPresent(w, src.Status); done {
return err
}
_, err := pgio.WriteInt32(w, int32(len(src.Bytes)))
if err != nil {
return nil
}
_, err = w.Write(src.Bytes)
return err
}

73
pgtype/bytea_test.go Normal file
View File

@ -0,0 +1,73 @@
package pgtype_test
import (
"reflect"
"testing"
"github.com/jackc/pgx/pgtype"
)
func TestByteaTranscode(t *testing.T) {
testSuccessfulTranscode(t, "bytea", []interface{}{
pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present},
pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present},
pgtype.Bytea{Bytes: nil, Status: pgtype.Null},
})
}
func TestByteaConvertFrom(t *testing.T) {
successfulTests := []struct {
source interface{}
result pgtype.Bytea
}{
{source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}},
{source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
{source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}},
{source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}},
{source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}},
{source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}},
}
for i, tt := range successfulTests {
var r pgtype.Bytea
err := r.ConvertFrom(tt.source)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if !reflect.DeepEqual(r, tt.result) {
t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r)
}
}
}
func TestByteaAssignTo(t *testing.T) {
var buf []byte
var _buf _byteSlice
var pbuf *[]byte
var _pbuf *_byteSlice
simpleTests := []struct {
src pgtype.Bytea
dst interface{}
expected interface{}
}{
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}},
{src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}},
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))},
{src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))},
}
for i, tt := range simpleTests {
err := tt.src.AssignTo(tt.dst)
if err != nil {
t.Errorf("%d: %v", i, err)
}
if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) {
t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst)
}
}
}

View File

@ -85,6 +85,27 @@ func underlyingBoolType(val interface{}) (interface{}, bool) {
return nil, false
}
// underlyingBytesType gets the underlying type that can be converted to []byte
func underlyingBytesType(val interface{}) (interface{}, bool) {
refVal := reflect.ValueOf(val)
switch refVal.Kind() {
case reflect.Ptr:
if refVal.IsNil() {
return nil, false
}
convVal := refVal.Elem().Interface()
return convVal, true
case reflect.Slice:
if refVal.Type().Elem().Kind() == reflect.Uint8 {
convVal := refVal.Bytes()
return convVal, reflect.TypeOf(convVal) != refVal.Type()
}
}
return nil, false
}
// underlyingStringType gets the underlying type that can be converted to String
func underlyingStringType(val interface{}) (interface{}, bool) {
refVal := reflect.ValueOf(val)

View File

@ -22,6 +22,7 @@ type _int32Slice []int32
type _int64Slice []int64
type _float32Slice []float32
type _float64Slice []float64
type _byteSlice []byte
func mustConnectPgx(t testing.TB) *pgx.Conn {
config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))