mirror of https://github.com/jackc/pgx.git
Made many things public so SelectFunc is actually usable by others
Definitely, need to add higher level methods for other packages to use. May rehide some of these interfaces at that point.pgx-vs-pq
parent
19d4a4d577
commit
78590be058
|
@ -61,7 +61,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *messageReader
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
switch t {
|
||||
case backendKeyData:
|
||||
|
@ -93,17 +93,17 @@ func (c *Connection) Close() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fieldDescription) error) (err error) {
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*MessageReader, []FieldDescription) error) (err error) {
|
||||
if err = c.sendSimpleQuery(sql); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var callbackError error
|
||||
var fields []fieldDescription
|
||||
var fields []FieldDescription
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *messageReader
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
|
@ -137,7 +137,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*messageReader, []fie
|
|||
// pattern when accessing the map
|
||||
func (c *Connection) SelectRows(sql string) (rows []map[string]string, err error) {
|
||||
rows = make([]map[string]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
|
||||
rows = append(rows, c.rxDataRow(r, fields))
|
||||
return nil
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
|
|||
|
||||
for {
|
||||
var t byte
|
||||
var r *messageReader
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
|
@ -172,7 +172,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
|
|||
case rowDescription:
|
||||
case dataRow:
|
||||
case commandComplete:
|
||||
commandTag = r.readString()
|
||||
commandTag = r.ReadString()
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return
|
||||
|
@ -189,7 +189,7 @@ func (c *Connection) Execute(sql string) (commandTag string, err error) {
|
|||
// Processes messages that are not exclusive to one context such as
|
||||
// authentication or query response. The response to these messages
|
||||
// is the same regardless of when they occur.
|
||||
func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error) {
|
||||
func (c *Connection) processContextFreeMsg(t byte, r *MessageReader) (err error) {
|
||||
switch t {
|
||||
case 'S':
|
||||
c.rxParameterStatus(r)
|
||||
|
@ -206,7 +206,7 @@ func (c *Connection) processContextFreeMsg(t byte, r *messageReader) (err error)
|
|||
|
||||
}
|
||||
|
||||
func (c *Connection) rxMsg() (t byte, r *messageReader, err error) {
|
||||
func (c *Connection) rxMsg() (t byte, r *MessageReader, err error) {
|
||||
var bodySize int32
|
||||
t, bodySize, err = c.rxMsgHeader()
|
||||
if err != nil {
|
||||
|
@ -239,14 +239,14 @@ func (c *Connection) rxMsgBody(bodySize int32) (buf []byte, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxAuthenticationX(r *messageReader) (err error) {
|
||||
code := r.readInt32()
|
||||
func (c *Connection) rxAuthenticationX(r *MessageReader) (err error) {
|
||||
code := r.ReadInt32()
|
||||
switch code {
|
||||
case 0: // AuthenticationOk
|
||||
case 3: // AuthenticationCleartextPassword
|
||||
c.txPasswordMessage(c.parameters.Password)
|
||||
case 5: // AuthenticationMD5Password
|
||||
salt := r.readByteString(4)
|
||||
salt := r.ReadByteString(4)
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.Password+c.parameters.User)+salt)
|
||||
c.txPasswordMessage(digestedPassword)
|
||||
default:
|
||||
|
@ -262,75 +262,75 @@ func hexMD5(s string) string {
|
|||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (c *Connection) rxParameterStatus(r *messageReader) {
|
||||
key := r.readString()
|
||||
value := r.readString()
|
||||
func (c *Connection) rxParameterStatus(r *MessageReader) {
|
||||
key := r.ReadString()
|
||||
value := r.ReadString()
|
||||
c.runtimeParams[key] = value
|
||||
}
|
||||
|
||||
func (c *Connection) rxErrorResponse(r *messageReader) (err PgError) {
|
||||
func (c *Connection) rxErrorResponse(r *MessageReader) (err PgError) {
|
||||
for {
|
||||
switch r.readByte() {
|
||||
switch r.ReadByte() {
|
||||
case 'S':
|
||||
err.Severity = r.readString()
|
||||
err.Severity = r.ReadString()
|
||||
case 'C':
|
||||
err.Code = r.readString()
|
||||
err.Code = r.ReadString()
|
||||
case 'M':
|
||||
err.Message = r.readString()
|
||||
err.Message = r.ReadString()
|
||||
case 0: // End of error message
|
||||
return
|
||||
default: // Ignore other error fields
|
||||
r.readString()
|
||||
r.ReadString()
|
||||
}
|
||||
}
|
||||
|
||||
panic("Unreachable")
|
||||
}
|
||||
|
||||
func (c *Connection) rxBackendKeyData(r *messageReader) {
|
||||
c.pid = r.readInt32()
|
||||
c.secretKey = r.readInt32()
|
||||
func (c *Connection) rxBackendKeyData(r *MessageReader) {
|
||||
c.pid = r.ReadInt32()
|
||||
c.secretKey = r.ReadInt32()
|
||||
}
|
||||
|
||||
func (c *Connection) rxReadyForQuery(r *messageReader) {
|
||||
c.txStatus = r.readByte()
|
||||
func (c *Connection) rxReadyForQuery(r *MessageReader) {
|
||||
c.txStatus = r.ReadByte()
|
||||
}
|
||||
|
||||
func (c *Connection) rxRowDescription(r *messageReader) (fields []fieldDescription) {
|
||||
fieldCount := r.readInt16()
|
||||
fields = make([]fieldDescription, fieldCount)
|
||||
func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescription) {
|
||||
fieldCount := r.ReadInt16()
|
||||
fields = make([]FieldDescription, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
f := &fields[i]
|
||||
f.name = r.readString()
|
||||
f.table = r.readOid()
|
||||
f.attributeNumber = r.readInt16()
|
||||
f.dataType = r.readOid()
|
||||
f.dataTypeSize = r.readInt16()
|
||||
f.modifier = r.readInt32()
|
||||
f.formatCode = r.readInt16()
|
||||
f.Name = r.ReadString()
|
||||
f.Table = r.ReadOid()
|
||||
f.AttributeNumber = r.ReadInt16()
|
||||
f.DataType = r.ReadOid()
|
||||
f.DataTypeSize = r.ReadInt16()
|
||||
f.Modifier = r.ReadInt32()
|
||||
f.FormatCode = r.ReadInt16()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxDataRow(r *messageReader, fields []fieldDescription) (row map[string]string) {
|
||||
fieldCount := r.readInt16()
|
||||
func (c *Connection) rxDataRow(r *MessageReader, fields []FieldDescription) (row map[string]string) {
|
||||
fieldCount := r.ReadInt16()
|
||||
|
||||
row = make(map[string]string, fieldCount)
|
||||
for i := int16(0); i < fieldCount; i++ {
|
||||
size := r.readInt32()
|
||||
size := r.ReadInt32()
|
||||
if size > -1 {
|
||||
row[fields[i].name] = r.readByteString(size)
|
||||
row[fields[i].Name] = r.ReadByteString(size)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool) {
|
||||
r.readInt16() // ignore field count
|
||||
func (c *Connection) rxDataRowFirstValue(r *MessageReader) (s string, null bool) {
|
||||
r.ReadInt16() // ignore field count
|
||||
|
||||
size := r.readInt32()
|
||||
size := r.ReadInt32()
|
||||
if size > -1 {
|
||||
s = r.readByteString(size)
|
||||
s = r.ReadByteString(size)
|
||||
} else {
|
||||
null = true
|
||||
}
|
||||
|
@ -338,8 +338,8 @@ func (c *Connection) rxDataRowFirstValue(r *messageReader) (s string, null bool)
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxCommandComplete(r *messageReader) string {
|
||||
return r.readString()
|
||||
func (c *Connection) rxCommandComplete(r *MessageReader) string {
|
||||
return r.ReadString()
|
||||
}
|
||||
|
||||
func (c *Connection) txStartupMessage(msg *startupMessage) (err error) {
|
||||
|
|
|
@ -7,7 +7,7 @@ import (
|
|||
|
||||
func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
||||
strings = make([]string, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
@ -21,7 +21,7 @@ func (c *Connection) SelectAllString(sql string) (strings []string, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
||||
ints = make([]int64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
@ -37,7 +37,7 @@ func (c *Connection) SelectAllInt64(sql string) (ints []int64, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
||||
ints = make([]int32, 0, 8)
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) (parseError error) {
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
@ -53,7 +53,7 @@ func (c *Connection) SelectAllInt32(sql string) (ints []int32, err error) {
|
|||
|
||||
func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
||||
ints = make([]int16, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
@ -69,7 +69,7 @@ func (c *Connection) SelectAllInt16(sql string) (ints []int16, err error) {
|
|||
|
||||
func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error) {
|
||||
floats = make([]float64, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
@ -85,7 +85,7 @@ func (c *Connection) SelectAllFloat64(sql string) (floats []float64, err error)
|
|||
|
||||
func (c *Connection) SelectAllFloat32(sql string) (floats []float32, err error) {
|
||||
floats = make([]float32, 0, 8)
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) (parseError error) {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) (parseError error) {
|
||||
s, null := c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
return errors.New("Unexpected NULL")
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func (c *Connection) SelectString(sql string) (s string, err error) {
|
||||
onDataRow := func(r *messageReader, _ []fieldDescription) error {
|
||||
onDataRow := func(r *MessageReader, _ []FieldDescription) error {
|
||||
var null bool
|
||||
s, null = c.rxDataRowFirstValue(r)
|
||||
if null {
|
||||
|
|
|
@ -130,7 +130,7 @@ func TestSelectFunc(t *testing.T) {
|
|||
conn := getSharedConnection()
|
||||
|
||||
rowCount := 0
|
||||
onDataRow := func(r *messageReader, fields []fieldDescription) error {
|
||||
onDataRow := func(r *MessageReader, fields []FieldDescription) error {
|
||||
rowCount++
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,38 +5,38 @@ import (
|
|||
"encoding/binary"
|
||||
)
|
||||
|
||||
type messageReader []byte
|
||||
type MessageReader []byte
|
||||
|
||||
func newMessageReader(buf []byte) *messageReader {
|
||||
r := messageReader(buf)
|
||||
func newMessageReader(buf []byte) *MessageReader {
|
||||
r := MessageReader(buf)
|
||||
return &r
|
||||
}
|
||||
|
||||
func (r *messageReader) readByte() byte {
|
||||
func (r *MessageReader) ReadByte() byte {
|
||||
b := (*r)[0]
|
||||
*r = (*r)[1:]
|
||||
return b
|
||||
}
|
||||
|
||||
func (r *messageReader) readInt16() int16 {
|
||||
func (r *MessageReader) ReadInt16() int16 {
|
||||
n := int16(binary.BigEndian.Uint16((*r)[:2]))
|
||||
*r = (*r)[2:]
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *messageReader) readInt32() int32 {
|
||||
func (r *MessageReader) ReadInt32() int32 {
|
||||
n := int32(binary.BigEndian.Uint32((*r)[:4]))
|
||||
*r = (*r)[4:]
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *messageReader) readOid() oid {
|
||||
func (r *MessageReader) ReadOid() oid {
|
||||
n := oid(binary.BigEndian.Uint32((*r)[:4]))
|
||||
*r = (*r)[4:]
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *messageReader) readString() string {
|
||||
func (r *MessageReader) ReadString() string {
|
||||
n := bytes.IndexByte(*r, 0)
|
||||
s := (*r)[:n]
|
||||
*r = (*r)[n+1:]
|
||||
|
@ -44,7 +44,7 @@ func (r *messageReader) readString() string {
|
|||
}
|
||||
|
||||
// Read count bytes and return as string
|
||||
func (r *messageReader) readByteString(count int32) string {
|
||||
func (r *MessageReader) ReadByteString(count int32) string {
|
||||
s := (*r)[:count]
|
||||
*r = (*r)[count:]
|
||||
return string(s)
|
||||
|
|
16
messages.go
16
messages.go
|
@ -43,14 +43,14 @@ func (self *startupMessage) Bytes() (buf []byte) {
|
|||
|
||||
type oid int32
|
||||
|
||||
type fieldDescription struct {
|
||||
name string
|
||||
table oid
|
||||
attributeNumber int16
|
||||
dataType oid
|
||||
dataTypeSize int16
|
||||
modifier int32
|
||||
formatCode int16
|
||||
type FieldDescription struct {
|
||||
Name string
|
||||
Table oid
|
||||
AttributeNumber int16
|
||||
DataType oid
|
||||
DataTypeSize int16
|
||||
Modifier int32
|
||||
FormatCode int16
|
||||
}
|
||||
|
||||
type PgError struct {
|
||||
|
|
Loading…
Reference in New Issue