package pgx import ( "bytes" "encoding/binary" "errors" "net" "github.com/jackc/pgx/chunkreader" ) // msgReader is a helper that reads values from a PostgreSQL message. type msgReader struct { cr *chunkreader.ChunkReader msgType byte msgBody []byte rp int // read position err error log func(lvl int, msg string, ctx ...interface{}) shouldLog func(lvl int) bool } // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.fatal", "error", err, "msgType", r.msgType, "msgBody", r.msgBody, "rp", r.rp) } r.err = err } // rxMsg reads the type and size of the next message. func (r *msgReader) rxMsg() (byte, error) { if r.err != nil { return 0, r.err } header, err := r.cr.Next(5) if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) } return 0, err } r.msgType = header[0] bodyLen := int(binary.BigEndian.Uint32(header[1:])) - 4 r.msgBody, err = r.cr.Next(bodyLen) if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { r.fatal(err) } return 0, err } r.rp = 0 return r.msgType, nil } func (r *msgReader) readByte() byte { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 1 { r.fatal(errors.New("read past end of message")) return 0 } b := r.msgBody[r.rp] r.rp++ if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readByte", "value", b, "byteAsString", string(b), "msgType", r.msgType, "rp", r.rp) } return b } func (r *msgReader) readInt16() int16 { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } n := int16(binary.BigEndian.Uint16(r.msgBody[r.rp:])) r.rp += 2 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n } func (r *msgReader) readInt32() int32 { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } n := int32(binary.BigEndian.Uint32(r.msgBody[r.rp:])) r.rp += 4 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n } func (r *msgReader) readUint16() uint16 { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 2 { r.fatal(errors.New("read past end of message")) return 0 } n := binary.BigEndian.Uint16(r.msgBody[r.rp:]) r.rp += 2 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readUint16", "value", n, "msgType", r.msgType, "rp", r.rp) } return n } func (r *msgReader) readUint32() uint32 { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 4 { r.fatal(errors.New("read past end of message")) return 0 } n := binary.BigEndian.Uint32(r.msgBody[r.rp:]) r.rp += 4 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readUint32", "value", n, "msgType", r.msgType, "rp", r.rp) } return n } func (r *msgReader) readInt64() int64 { if r.err != nil { return 0 } if len(r.msgBody)-r.rp < 8 { r.fatal(errors.New("read past end of message")) return 0 } n := int64(binary.BigEndian.Uint64(r.msgBody[r.rp:])) r.rp += 8 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readInt64", "value", n, "msgType", r.msgType, "rp", r.rp) } return n } // readCString reads a null terminated string func (r *msgReader) readCString() string { if r.err != nil { return "" } nullIdx := bytes.IndexByte(r.msgBody[r.rp:], 0) if nullIdx == -1 { r.fatal(errors.New("null terminated string not found")) return "" } s := string(r.msgBody[r.rp : r.rp+nullIdx]) r.rp += nullIdx + 1 if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readCString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s } // readString reads count bytes and returns as string func (r *msgReader) readString(countI32 int32) string { if r.err != nil { return "" } count := int(countI32) if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return "" } s := string(r.msgBody[r.rp : r.rp+count]) r.rp += count if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readString", "value", s, "msgType", r.msgType, "rp", r.rp) } return s } // readBytes reads count bytes and returns as []byte func (r *msgReader) readBytes(countI32 int32) []byte { if r.err != nil { return nil } count := int(countI32) if len(r.msgBody)-r.rp < count { r.fatal(errors.New("read past end of message")) return nil } b := r.msgBody[r.rp : r.rp+count] r.rp += count r.cr.KeepLast() if r.shouldLog(LogLevelTrace) { r.log(LogLevelTrace, "msgReader.readBytes", "value", b, r.msgType, "rp", r.rp) } return b }