Remove MessageWriter

scan-io
Jack Christensen 2014-06-19 18:10:04 -05:00
parent 3b9a1ce659
commit 772c6ca7d7
3 changed files with 92 additions and 113 deletions

61
conn.go
View File

@ -588,27 +588,21 @@ func (c *Conn) Prepare(name, sql string) (err error) {
// parse // parse
buf := c.getBuf() buf := c.getBuf()
w := newMessageWriter(buf) buf.WriteString(name)
w.WriteCString(name) buf.WriteByte(0)
w.WriteCString(sql) buf.WriteString(sql)
w.Write(int16(0)) buf.WriteByte(0)
if w.Err != nil { binary.Write(buf, binary.BigEndian, int16(0))
return w.Err
}
err = c.txMsg('P', buf, false) err = c.txMsg('P', buf, false)
if err != nil { if err != nil {
return return err
} }
// describe // describe
buf = c.getBuf() buf = c.getBuf()
w = newMessageWriter(buf) buf.WriteByte('S')
w.WriteByte('S') buf.WriteString(name)
w.WriteCString(name) buf.WriteByte(0)
if w.Err != nil {
return w.Err
}
err = c.txMsg('D', buf, false) err = c.txMsg('D', buf, false)
if err != nil { if err != nil {
return return
@ -774,46 +768,44 @@ func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}
// bind // bind
buf := c.getBuf() buf := c.getBuf()
w := newMessageWriter(buf) buf.WriteString("")
w.WriteCString("") buf.WriteByte(0)
w.WriteCString(ps.Name) buf.WriteString(ps.Name)
w.Write(int16(len(ps.ParameterOids))) buf.WriteByte(0)
binary.Write(buf, binary.BigEndian, int16(len(ps.ParameterOids)))
for _, oid := range ps.ParameterOids { for _, oid := range ps.ParameterOids {
transcoder := ValueTranscoders[oid] transcoder := ValueTranscoders[oid]
if transcoder == nil { if transcoder == nil {
transcoder = defaultTranscoder transcoder = defaultTranscoder
} }
w.Write(transcoder.EncodeFormat) binary.Write(buf, binary.BigEndian, transcoder.EncodeFormat)
} }
w.Write(int16(len(arguments))) binary.Write(buf, binary.BigEndian, int16(len(arguments)))
for i, oid := range ps.ParameterOids { for i, oid := range ps.ParameterOids {
if arguments[i] != nil { if arguments[i] != nil {
transcoder := ValueTranscoders[oid] transcoder := ValueTranscoders[oid]
if transcoder == nil { if transcoder == nil {
transcoder = defaultTranscoder transcoder = defaultTranscoder
} }
err = transcoder.EncodeTo(w.buf, arguments[i]) err = transcoder.EncodeTo(buf, arguments[i])
if err != nil { if err != nil {
return err return err
} }
} else { } else {
w.Write(int32(-1)) binary.Write(buf, binary.BigEndian, int32(-1))
} }
} }
w.Write(int16(len(ps.FieldDescriptions))) binary.Write(buf, binary.BigEndian, int16(len(ps.FieldDescriptions)))
for _, fd := range ps.FieldDescriptions { for _, fd := range ps.FieldDescriptions {
transcoder := ValueTranscoders[fd.DataType] transcoder := ValueTranscoders[fd.DataType]
if transcoder != nil && transcoder.DecodeBinary != nil { if transcoder != nil && transcoder.DecodeBinary != nil {
w.Write(int16(1)) binary.Write(buf, binary.BigEndian, int16(1))
} else { } else {
w.Write(int16(0)) binary.Write(buf, binary.BigEndian, int16(0))
} }
} }
if w.Err != nil {
return w.Err
}
err = c.txMsg('B', buf, false) err = c.txMsg('B', buf, false)
if err != nil { if err != nil {
@ -822,14 +814,9 @@ func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}
// execute // execute
buf = c.getBuf() buf = c.getBuf()
w = newMessageWriter(buf) buf.WriteString("")
w.WriteCString("") buf.WriteByte(0)
w.Write(int32(0)) binary.Write(buf, binary.BigEndian, int32(0))
if w.Err != nil {
return w.Err
}
err = c.txMsg('E', buf, false) err = c.txMsg('E', buf, false)
if err != nil { if err != nil {
return err return err

View File

@ -1,58 +0,0 @@
package pgx
import (
"bytes"
"encoding/binary"
)
// MessageWriter is a helper for producing messages to send to PostgreSQL.
// To avoid verbose error handling it internally records errors and no-ops
// any calls that occur after an error. At the end of a sequence of writes
// the Err field should be checked to see if any errors occurred.
type MessageWriter struct {
buf *bytes.Buffer
Err error
}
func newMessageWriter(buf *bytes.Buffer) *MessageWriter {
return &MessageWriter{buf: buf}
}
// WriteCString writes a null-terminated string.
func (w *MessageWriter) WriteCString(s string) {
if w.Err != nil {
return
}
if _, w.Err = w.buf.WriteString(s); w.Err != nil {
return
}
w.Err = w.buf.WriteByte(0)
}
// WriteString writes a string without a null terminator.
func (w *MessageWriter) WriteString(s string) {
if w.Err != nil {
return
}
if _, w.Err = w.buf.WriteString(s); w.Err != nil {
return
}
}
func (w *MessageWriter) WriteByte(b byte) {
if w.Err != nil {
return
}
w.Err = w.buf.WriteByte(b)
}
// Write writes data in the network byte order. data can be an integer type,
// float type, or byte slice.
func (w *MessageWriter) Write(data interface{}) {
if w.Err != nil {
return
}
w.Err = binary.Write(w.buf, binary.BigEndian, data)
}

View File

@ -462,16 +462,32 @@ func decodeInt2ArrayFromText(mr *MessageReader, size int32) interface{} {
} }
func int16SliceToArrayString(nums []int16) (string, error) { func int16SliceToArrayString(nums []int16) (string, error) {
w := newMessageWriter(&bytes.Buffer{}) w := &bytes.Buffer{}
w.WriteString("{") _, err := w.WriteString("{")
if err != nil {
return "", err
}
for i, n := range nums { for i, n := range nums {
if i > 0 { if i > 0 {
w.WriteString(",") _, err = w.WriteString(",")
if err != nil {
return "", err
}
}
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
if err != nil {
return "", err
} }
w.WriteString(strconv.FormatInt(int64(n), 10))
} }
w.WriteString("}")
return w.buf.String(), w.Err _, err = w.WriteString("}")
if err != nil {
return "", err
}
return w.String(), nil
} }
func encodeInt2Array(w io.Writer, value interface{}) error { func encodeInt2Array(w io.Writer, value interface{}) error {
@ -513,16 +529,33 @@ func decodeInt4ArrayFromText(mr *MessageReader, size int32) interface{} {
} }
func int32SliceToArrayString(nums []int32) (string, error) { func int32SliceToArrayString(nums []int32) (string, error) {
w := newMessageWriter(&bytes.Buffer{}) w := &bytes.Buffer{}
w.WriteString("{")
_, err := w.WriteString("{")
if err != nil {
return "", err
}
for i, n := range nums { for i, n := range nums {
if i > 0 { if i > 0 {
w.WriteString(",") _, err = w.WriteString(",")
if err != nil {
return "", err
}
}
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
if err != nil {
return "", err
} }
w.WriteString(strconv.FormatInt(int64(n), 10))
} }
w.WriteString("}")
return w.buf.String(), w.Err _, err = w.WriteString("}")
if err != nil {
return "", err
}
return w.String(), nil
} }
func encodeInt4Array(w io.Writer, value interface{}) error { func encodeInt4Array(w io.Writer, value interface{}) error {
@ -564,16 +597,33 @@ func decodeInt8ArrayFromText(mr *MessageReader, size int32) interface{} {
} }
func int64SliceToArrayString(nums []int64) (string, error) { func int64SliceToArrayString(nums []int64) (string, error) {
w := newMessageWriter(&bytes.Buffer{}) w := &bytes.Buffer{}
w.WriteString("{")
_, err := w.WriteString("{")
if err != nil {
return "", err
}
for i, n := range nums { for i, n := range nums {
if i > 0 { if i > 0 {
w.WriteString(",") _, err = w.WriteString(",")
if err != nil {
return "", err
}
}
_, err = w.WriteString(strconv.FormatInt(int64(n), 10))
if err != nil {
return "", err
} }
w.WriteString(strconv.FormatInt(int64(n), 10))
} }
w.WriteString("}")
return w.buf.String(), w.Err _, err = w.WriteString("}")
if err != nil {
return "", err
}
return w.String(), nil
} }
func encodeInt8Array(w io.Writer, value interface{}) error { func encodeInt8Array(w io.Writer, value interface{}) error {