pgxtype-experiment
Jack Christensen 2017-02-20 13:20:00 -06:00
parent 62f1adb342
commit 8b07d97d13
5 changed files with 176 additions and 27 deletions

View File

@ -101,6 +101,7 @@ func newWriteBuf(c *Conn, t byte) *WriteBuf {
// by the Encoder interface when implementing custom encoders.
type WriteBuf struct {
buf []byte
convBuf [8]byte
sizeIdx int
conn *Conn
}
@ -125,41 +126,32 @@ func (wb *WriteBuf) WriteCString(s string) {
}
func (wb *WriteBuf) WriteInt16(n int16) {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, uint16(n))
wb.buf = append(wb.buf, b...)
wb.WriteUint16(uint16(n))
}
func (wb *WriteBuf) WriteUint16(n uint16) (int, error) {
b := make([]byte, 2)
binary.BigEndian.PutUint16(b, n)
wb.buf = append(wb.buf, b...)
binary.BigEndian.PutUint16(wb.convBuf[:2], n)
wb.buf = append(wb.buf, wb.convBuf[:2]...)
return 2, nil
}
func (wb *WriteBuf) WriteInt32(n int32) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, uint32(n))
wb.buf = append(wb.buf, b...)
wb.WriteUint32(uint32(n))
}
func (wb *WriteBuf) WriteUint32(n uint32) (int, error) {
b := make([]byte, 4)
binary.BigEndian.PutUint32(b, n)
wb.buf = append(wb.buf, b...)
binary.BigEndian.PutUint32(wb.convBuf[:4], n)
wb.buf = append(wb.buf, wb.convBuf[:4]...)
return 4, nil
}
func (wb *WriteBuf) WriteInt64(n int64) {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(n))
wb.buf = append(wb.buf, b...)
wb.WriteUint64(uint64(n))
}
func (wb *WriteBuf) WriteUint64(n uint64) (int, error) {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, n)
wb.buf = append(wb.buf, b...)
binary.BigEndian.PutUint64(wb.convBuf[:8], n)
wb.buf = append(wb.buf, wb.convBuf[:8]...)
return 8, nil
}

View File

@ -1,7 +1,6 @@
package pgtype
import (
"encoding/binary"
"fmt"
"io"
"strconv"
@ -9,8 +8,23 @@ import (
type Int4 int32
func (i *Int4) ParseText(src string) error {
n, err := strconv.ParseInt(src, 10, 32)
func (i *Int4) DecodeText(r io.Reader) error {
size, err := ReadInt32(r)
if err != nil {
return err
}
if size == -1 {
return fmt.Errorf("invalid length for int4: %v", size)
}
buf := make([]byte, int(size))
_, err = r.Read(buf)
if err != nil {
return err
}
n, err := strconv.ParseInt(string(buf), 10, 32)
if err != nil {
return err
}
@ -19,12 +33,22 @@ func (i *Int4) ParseText(src string) error {
return nil
}
func (i *Int4) ParseBinary(src []byte) error {
if len(src) != 4 {
return fmt.Errorf("invalid length for int4: %v", len(src))
func (i *Int4) DecodeBinary(r io.Reader) error {
size, err := ReadInt32(r)
if err != nil {
return err
}
*i = Int4(binary.BigEndian.Uint32(src))
if size != 4 {
return fmt.Errorf("invalid length for int4: %v", size)
}
n, err := ReadInt32(r)
if err != nil {
return err
}
*i = Int4(n)
return nil
}

104
pgtype/typed_reader.go Normal file
View File

@ -0,0 +1,104 @@
package pgtype
import (
"encoding/binary"
"io"
)
type uint16Reader interface {
ReadUint16() (n uint16, err error)
}
type uint32Reader interface {
ReadUint32() (n uint32, err error)
}
type uint64Reader interface {
ReadUint64() (n uint64, err error)
}
// ReadByte reads a byte from r.
func ReadByte(r io.Reader) (byte, error) {
if r, ok := r.(io.ByteReader); ok {
return r.ReadByte()
}
buf := make([]byte, 1)
_, err := r.Read(buf)
return buf[0], err
}
// ReadUint16 reads an uint16 from r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint16
// method.
func ReadUint16(r io.Reader) (uint16, error) {
if r, ok := r.(uint16Reader); ok {
return r.ReadUint16()
}
buf := make([]byte, 2)
_, err := io.ReadFull(r, buf)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint16(buf), nil
}
// ReadInt16 reads an int16 r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint16
// method.
func ReadInt16(r io.Reader) (int16, error) {
n, err := ReadUint16(r)
return int16(n), err
}
// ReadUint32 reads an uint32 r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint32
// method.
func ReadUint32(r io.Reader) (uint32, error) {
if r, ok := r.(uint32Reader); ok {
return r.ReadUint32()
}
buf := make([]byte, 4)
_, err := io.ReadFull(r, buf)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint32(buf), nil
}
// ReadInt32 reads an int32 r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint32
// method.
func ReadInt32(r io.Reader) (int32, error) {
n, err := ReadUint32(r)
return int32(n), err
}
// ReadUint64 reads an uint64 r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint64
// method.
func ReadUint64(r io.Reader) (uint64, error) {
if r, ok := r.(uint64Reader); ok {
return r.ReadUint64()
}
buf := make([]byte, 8)
_, err := io.ReadFull(r, buf)
if err != nil {
return 0, err
}
return binary.BigEndian.Uint64(buf), nil
}
// ReadInt64 reads an int64 r in PostgreSQL wire format (network byte order). This
// may be more efficient than directly using Read if r provides a ReadUint64
// method.
func ReadInt64(r io.Reader) (int64, error) {
n, err := ReadUint64(r)
return int64(n), err
}

View File

@ -4,6 +4,8 @@ import (
"errors"
)
var errRewoundLen = errors.New("len was rewound")
// ValueReader is used by the Scanner interface to decode values.
type ValueReader struct {
mr *msgReader
@ -154,3 +156,28 @@ func (r *ValueReader) ReadBytes(count int32) []byte {
return r.mr.readBytes(count)
}
type valueReader2 struct {
*ValueReader
}
func (r *valueReader2) Read(dst []byte) (int, error) {
if r.err != nil {
return 0, r.err
}
src := r.ReadBytes(int32(len(dst)))
copy(dst, src)
return len(dst), nil
}
func (r *valueReader2) ReadUint32() (uint32, error) {
if r.err == errRewoundLen {
r.err = nil
return uint32(r.Len()), nil
}
return r.ValueReader.ReadUint32(), nil
}

View File

@ -1789,13 +1789,15 @@ func decodeInt4(vr *ValueReader) int32 {
return 0
}
vr.err = errRewoundLen
var n pgtype.Int4
var err error
switch vr.Type().FormatCode {
case TextFormatCode:
err = n.ParseText(vr.ReadString(vr.Len()))
err = n.DecodeText(&valueReader2{vr})
case BinaryFormatCode:
err = n.ParseBinary(vr.ReadBytes(vr.Len()))
err = n.DecodeBinary(&valueReader2{vr})
default:
vr.Fatal(ProtocolError(fmt.Sprintf("Unknown field description format code: %v", vr.Type().FormatCode)))
return 0