From bac4af13bb9df8725ea56e4aa709f8ad17bd7a0d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Mar 2017 21:07:40 -0600 Subject: [PATCH] Add bytea --- pgtype/bytea.go | 160 ++++++++++++++++++++++++++++++++++++++++++ pgtype/bytea_test.go | 73 +++++++++++++++++++ pgtype/convert.go | 21 ++++++ pgtype/pgtype_test.go | 1 + 4 files changed, 255 insertions(+) create mode 100644 pgtype/bytea.go create mode 100644 pgtype/bytea_test.go diff --git a/pgtype/bytea.go b/pgtype/bytea.go new file mode 100644 index 00000000..2532182f --- /dev/null +++ b/pgtype/bytea.go @@ -0,0 +1,160 @@ +package pgtype + +import ( + "encoding/hex" + "fmt" + "io" + "reflect" + + "github.com/jackc/pgx/pgio" +) + +type Bytea struct { + Bytes []byte + Status Status +} + +func (dst *Bytea) ConvertFrom(src interface{}) error { + switch value := src.(type) { + case Bytea: + *dst = value + case []byte: + if value != nil { + *dst = Bytea{Bytes: value, Status: Present} + } else { + *dst = Bytea{Status: Null} + } + default: + if originalSrc, ok := underlyingBytesType(src); ok { + return dst.ConvertFrom(originalSrc) + } + return fmt.Errorf("cannot convert %v to Bytea", value) + } + + return nil +} + +func (src *Bytea) AssignTo(dst interface{}) error { + switch v := dst.(type) { + case *[]byte: + if src.Status == Present { + *v = src.Bytes + } else { + *v = nil + } + default: + if v := reflect.ValueOf(dst); v.Kind() == reflect.Ptr { + el := v.Elem() + switch el.Kind() { + // if dst is a pointer to pointer, strip the pointer and try again + case reflect.Ptr: + if src.Status == Null { + el.Set(reflect.Zero(el.Type())) + return nil + } + if el.IsNil() { + // allocate destination + el.Set(reflect.New(el.Type().Elem())) + } + return src.AssignTo(el.Interface()) + default: + if originalDst, ok := underlyingPtrSliceType(dst); ok { + return src.AssignTo(originalDst) + } + } + } + return fmt.Errorf("cannot decode %v into %T", src, dst) + } + + return nil +} + +// DecodeText only supports the hex format. This has been the default since +// PostgreSQL 9.0. +func (dst *Bytea) DecodeText(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + sbuf := make([]byte, int(size)) + _, err = io.ReadFull(r, sbuf) + if err != nil { + return err + } + + if len(sbuf) < 2 || sbuf[0] != '\\' || sbuf[1] != 'x' { + return fmt.Errorf("invalid hex format") + } + + buf := make([]byte, (len(sbuf)-2)/2) + _, err = hex.Decode(buf, sbuf[2:]) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (dst *Bytea) DecodeBinary(r io.Reader) error { + size, err := pgio.ReadInt32(r) + if err != nil { + return err + } + + if size == -1 { + *dst = Bytea{Status: Null} + return nil + } + + buf := make([]byte, int(size)) + + _, err = io.ReadFull(r, buf) + if err != nil { + return err + } + + *dst = Bytea{Bytes: buf, Status: Present} + return nil +} + +func (src Bytea) EncodeText(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + str := hex.EncodeToString(src.Bytes) + + _, err := pgio.WriteInt32(w, int32(len(str)+2)) + if err != nil { + return nil + } + + _, err = io.WriteString(w, `\x`) + if err != nil { + return nil + } + + _, err = io.WriteString(w, str) + return err +} + +func (src Bytea) EncodeBinary(w io.Writer) error { + if done, err := encodeNotPresent(w, src.Status); done { + return err + } + + _, err := pgio.WriteInt32(w, int32(len(src.Bytes))) + if err != nil { + return nil + } + + _, err = w.Write(src.Bytes) + return err +} diff --git a/pgtype/bytea_test.go b/pgtype/bytea_test.go new file mode 100644 index 00000000..51941387 --- /dev/null +++ b/pgtype/bytea_test.go @@ -0,0 +1,73 @@ +package pgtype_test + +import ( + "reflect" + "testing" + + "github.com/jackc/pgx/pgtype" +) + +func TestByteaTranscode(t *testing.T) { + testSuccessfulTranscode(t, "bytea", []interface{}{ + pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}, + pgtype.Bytea{Bytes: nil, Status: pgtype.Null}, + }) +} + +func TestByteaConvertFrom(t *testing.T) { + successfulTests := []struct { + source interface{} + result pgtype.Bytea + }{ + {source: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Null}}, + {source: []byte{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: []byte{}, result: pgtype.Bytea{Bytes: []byte{}, Status: pgtype.Present}}, + {source: []byte(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + {source: _byteSlice{1, 2, 3}, result: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}}, + {source: _byteSlice(nil), result: pgtype.Bytea{Status: pgtype.Null}}, + } + + for i, tt := range successfulTests { + var r pgtype.Bytea + err := r.ConvertFrom(tt.source) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if !reflect.DeepEqual(r, tt.result) { + t.Errorf("%d: expected %v to convert to %v, but it was %v", i, tt.source, tt.result, r) + } + } +} + +func TestByteaAssignTo(t *testing.T) { + var buf []byte + var _buf _byteSlice + var pbuf *[]byte + var _pbuf *_byteSlice + + simpleTests := []struct { + src pgtype.Bytea + dst interface{} + expected interface{} + }{ + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &buf, expected: []byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_buf, expected: _byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &pbuf, expected: &[]byte{1, 2, 3}}, + {src: pgtype.Bytea{Bytes: []byte{1, 2, 3}, Status: pgtype.Present}, dst: &_pbuf, expected: &_byteSlice{1, 2, 3}}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &pbuf, expected: ((*[]byte)(nil))}, + {src: pgtype.Bytea{Status: pgtype.Null}, dst: &_pbuf, expected: ((*_byteSlice)(nil))}, + } + + for i, tt := range simpleTests { + err := tt.src.AssignTo(tt.dst) + if err != nil { + t.Errorf("%d: %v", i, err) + } + + if dst := reflect.ValueOf(tt.dst).Elem().Interface(); !reflect.DeepEqual(dst, tt.expected) { + t.Errorf("%d: expected %v to assign %v, but result was %v", i, tt.src, tt.expected, dst) + } + } +} diff --git a/pgtype/convert.go b/pgtype/convert.go index 31bbf060..648209f5 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -85,6 +85,27 @@ func underlyingBoolType(val interface{}) (interface{}, bool) { return nil, false } +// underlyingBytesType gets the underlying type that can be converted to []byte +func underlyingBytesType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return nil, false + } + convVal := refVal.Elem().Interface() + return convVal, true + case reflect.Slice: + if refVal.Type().Elem().Kind() == reflect.Uint8 { + convVal := refVal.Bytes() + return convVal, reflect.TypeOf(convVal) != refVal.Type() + } + } + + return nil, false +} + // underlyingStringType gets the underlying type that can be converted to String func underlyingStringType(val interface{}) (interface{}, bool) { refVal := reflect.ValueOf(val) diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index c1dba383..6e173cbe 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -22,6 +22,7 @@ type _int32Slice []int32 type _int64Slice []int64 type _float32Slice []float32 type _float64Slice []float64 +type _byteSlice []byte func mustConnectPgx(t testing.TB) *pgx.Conn { config, err := pgx.ParseURI(os.Getenv("DATABASE_URL"))