diff --git a/messages.go b/messages.go index c2964b82..2320bf7f 100644 --- a/messages.go +++ b/messages.go @@ -130,10 +130,11 @@ func (wb *WriteBuf) WriteInt16(n int16) { wb.buf = append(wb.buf, b...) } -func (wb *WriteBuf) WriteUint16(n uint16) { +func (wb *WriteBuf) WriteUint16(n uint16) (int, error) { b := make([]byte, 2) binary.BigEndian.PutUint16(b, n) wb.buf = append(wb.buf, b...) + return 2, nil } func (wb *WriteBuf) WriteInt32(n int32) { @@ -142,10 +143,11 @@ func (wb *WriteBuf) WriteInt32(n int32) { wb.buf = append(wb.buf, b...) } -func (wb *WriteBuf) WriteUint32(n uint32) { +func (wb *WriteBuf) WriteUint32(n uint32) (int, error) { b := make([]byte, 4) binary.BigEndian.PutUint32(b, n) wb.buf = append(wb.buf, b...) + return 4, nil } func (wb *WriteBuf) WriteInt64(n int64) { @@ -154,6 +156,18 @@ func (wb *WriteBuf) WriteInt64(n int64) { wb.buf = append(wb.buf, b...) } +func (wb *WriteBuf) WriteUint64(n uint64) (int, error) { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + wb.buf = append(wb.buf, b...) + return 8, nil +} + func (wb *WriteBuf) WriteBytes(b []byte) { wb.buf = append(wb.buf, b...) } + +func (wb *WriteBuf) Write(b []byte) (int, error) { + wb.buf = append(wb.buf, b...) + return len(b), nil +} diff --git a/pgtype/int4.go b/pgtype/int4.go index cd0f1ed2..ab993476 100644 --- a/pgtype/int4.go +++ b/pgtype/int4.go @@ -3,6 +3,7 @@ package pgtype import ( "encoding/binary" "fmt" + "io" "strconv" ) @@ -27,12 +28,22 @@ func (i *Int4) ParseBinary(src []byte) error { return nil } -func (i Int4) FormatText() (string, error) { - return strconv.FormatInt(int64(i), 10), nil +func (i Int4) EncodeText(w io.Writer) error { + s := strconv.FormatInt(int64(i), 10) + _, err := WriteInt32(w, int32(len(s))) + if err != nil { + return nil + } + _, err = w.Write([]byte(s)) + return err } -func (i Int4) FormatBinary() ([]byte, error) { - buf := make([]byte, 4) - binary.BigEndian.PutUint32(buf, uint32(i)) - return buf, nil +func (i Int4) EncodeBinary(w io.Writer) error { + _, err := WriteInt32(w, 4) + if err != nil { + return err + } + + _, err = WriteInt32(w, int32(i)) + return err } diff --git a/pgtype/typed_writer.go b/pgtype/typed_writer.go new file mode 100644 index 00000000..3f175343 --- /dev/null +++ b/pgtype/typed_writer.go @@ -0,0 +1,97 @@ +package pgtype + +import ( + "encoding/binary" + "io" +) + +type uint16Writer interface { + WriteUint16(uint16) (n int, err error) +} + +type uint32Writer interface { + WriteUint32(uint32) (n int, err error) +} + +type uint64Writer interface { + WriteUint64(uint64) (n int, err error) +} + +// WriteByte writes b to w. +func WriteByte(w io.Writer, b byte) error { + if w, ok := w.(io.ByteWriter); ok { + return w.WriteByte(b) + } + _, err := w.Write([]byte{b}) + return err +} + +// WriteUint16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteUint16(w io.Writer, n uint16) (int, error) { + if w, ok := w.(uint16Writer); ok { + return w.WriteUint16(n) + } + b := make([]byte, 2) + binary.BigEndian.PutUint16(b, n) + return w.Write(b) +} + +// WriteInt16 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint16 +// method. +func WriteInt16(w io.Writer, n int16) (int, error) { + return WriteUint16(w, uint16(n)) +} + +// WriteUint32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteUint32(w io.Writer, n uint32) (int, error) { + if w, ok := w.(uint32Writer); ok { + return w.WriteUint32(n) + } + b := make([]byte, 4) + binary.BigEndian.PutUint32(b, n) + return w.Write(b) +} + +// WriteInt32 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint32 +// method. +func WriteInt32(w io.Writer, n int32) (int, error) { + return WriteUint32(w, uint32(n)) +} + +// WriteUint64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteUint64(w io.Writer, n uint64) (int, error) { + if w, ok := w.(uint64Writer); ok { + return w.WriteUint64(n) + } + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, n) + return w.Write(b) +} + +// WriteInt64 writes n to w in PostgreSQL wire format (network byte order). This +// may be more efficient than directly using Write if w provides a WriteUint64 +// method. +func WriteInt64(w io.Writer, n int64) (int, error) { + return WriteUint64(w, uint64(n)) +} + +// WriteCString writes s to w followed by a null byte. +func WriteCString(w io.Writer, s string) (int, error) { + n, err := io.WriteString(w, s) + if err != nil { + return n, err + } + err = WriteByte(w, 0) + if err != nil { + return n, err + } + return n + 1, nil +} diff --git a/values.go b/values.go index 8650c2da..ee53edfc 100644 --- a/values.go +++ b/values.go @@ -1681,12 +1681,10 @@ func encodeInt32(w *WriteBuf, oid OID, value int32) error { return fmt.Errorf("%d is greater than max int16 %d", value, math.MaxInt16) } case Int4OID: - w.WriteInt32(4) - buf, err := pgtype.Int4(value).FormatBinary() + err := pgtype.Int4(value).EncodeBinary(w) if err != nil { return err } - w.WriteBytes(buf) case Int8OID: w.WriteInt32(8) w.WriteInt64(int64(value))