mirror of https://github.com/jackc/pgx.git
Dirty, but somewhat working prepared statements and extended protocol
parent
1042f095ee
commit
5073a3b9e0
200
connection.go
200
connection.go
|
@ -9,6 +9,8 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectionParameters struct {
|
type ConnectionParameters struct {
|
||||||
|
@ -28,6 +30,13 @@ type Connection struct {
|
||||||
runtimeParams map[string]string // parameters that have been reported by the server
|
runtimeParams map[string]string // parameters that have been reported by the server
|
||||||
parameters ConnectionParameters // parameters used when establishing this connection
|
parameters ConnectionParameters // parameters used when establishing this connection
|
||||||
txStatus byte
|
txStatus byte
|
||||||
|
preparedStatements map[string]*PreparedStatement
|
||||||
|
}
|
||||||
|
|
||||||
|
type PreparedStatement struct {
|
||||||
|
Name string
|
||||||
|
FieldDescriptions []FieldDescription
|
||||||
|
ParameterOids []oid
|
||||||
}
|
}
|
||||||
|
|
||||||
type NotSingleRowError struct {
|
type NotSingleRowError struct {
|
||||||
|
@ -71,6 +80,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
|
||||||
|
|
||||||
c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize))
|
c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize))
|
||||||
c.runtimeParams = make(map[string]string)
|
c.runtimeParams = make(map[string]string)
|
||||||
|
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||||
|
|
||||||
msg := newStartupMessage()
|
msg := newStartupMessage()
|
||||||
msg.options["user"] = c.parameters.User
|
msg.options["user"] = c.parameters.User
|
||||||
|
@ -108,12 +118,19 @@ func (c *Connection) Close() (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) {
|
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) {
|
||||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
var fields []FieldDescription
|
||||||
|
|
||||||
|
if ps, present := c.preparedStatements[sql]; present {
|
||||||
|
fields = ps.FieldDescriptions
|
||||||
|
err = c.sendPreparedQuery(ps, arguments...)
|
||||||
|
} else {
|
||||||
|
err = c.sendSimpleQuery(sql, arguments...)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var callbackError error
|
var callbackError error
|
||||||
var fields []FieldDescription
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
var t byte
|
var t byte
|
||||||
|
@ -132,6 +149,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
|
||||||
callbackError = onDataRow(newDataRowReader(r, fields))
|
callbackError = onDataRow(newDataRowReader(r, fields))
|
||||||
}
|
}
|
||||||
case commandComplete:
|
case commandComplete:
|
||||||
|
case bindComplete:
|
||||||
default:
|
default:
|
||||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
return
|
return
|
||||||
|
@ -207,6 +225,101 @@ func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Prepare(name, sql string) (err error) {
|
||||||
|
// parse
|
||||||
|
buf := c.getBuf()
|
||||||
|
_, err = buf.WriteString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = buf.WriteByte(0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = buf.WriteString(sql)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = buf.WriteByte(0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = binary.Write(buf, binary.BigEndian, int16(0))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.txMsg('P', buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// describe
|
||||||
|
buf = c.getBuf()
|
||||||
|
err = buf.WriteByte('S')
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = buf.WriteString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = buf.WriteByte(0)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.txMsg('D', buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// sync
|
||||||
|
err = c.txMsg('S', c.getBuf())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ps := PreparedStatement{Name: name}
|
||||||
|
|
||||||
|
for {
|
||||||
|
var t byte
|
||||||
|
var r *MessageReader
|
||||||
|
if t, r, err = c.rxMsg(); err == nil {
|
||||||
|
switch t {
|
||||||
|
case parseComplete:
|
||||||
|
case parameterDescription:
|
||||||
|
ps.ParameterOids = c.rxParameterDescription(r)
|
||||||
|
case rowDescription:
|
||||||
|
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||||
|
case readyForQuery:
|
||||||
|
c.preparedStatements[name] = &ps
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Deallocate(name string) (err error) {
|
||||||
|
delete(c.preparedStatements, name)
|
||||||
|
_, err = c.Execute("deallocate " + c.QuoteIdentifier(name))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
|
||||||
|
if ps, present := c.preparedStatements[sql]; present {
|
||||||
|
return c.sendPreparedQuery(ps, arguments...)
|
||||||
|
} else {
|
||||||
|
return c.sendSimpleQuery(sql, arguments...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
||||||
if len(arguments) > 0 {
|
if len(arguments) > 0 {
|
||||||
sql = c.SanitizeSql(sql, arguments...)
|
sql = c.SanitizeSql(sql, arguments...)
|
||||||
|
@ -226,8 +339,78 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err
|
||||||
return c.txMsg('Q', buf)
|
return c.txMsg('Q', buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
|
||||||
|
if len(ps.ParameterOids) != len(arguments) {
|
||||||
|
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
|
||||||
|
}
|
||||||
|
|
||||||
|
// bind
|
||||||
|
buf := c.getBuf()
|
||||||
|
buf.WriteString("")
|
||||||
|
buf.WriteByte(0)
|
||||||
|
buf.WriteString(ps.Name)
|
||||||
|
buf.WriteByte(0)
|
||||||
|
binary.Write(buf, binary.BigEndian, int16(0))
|
||||||
|
binary.Write(buf, binary.BigEndian, int16(len(arguments)))
|
||||||
|
for _, iArg := range arguments {
|
||||||
|
var s string
|
||||||
|
switch arg := iArg.(type) {
|
||||||
|
case string:
|
||||||
|
s = arg
|
||||||
|
case int16:
|
||||||
|
s = strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int32:
|
||||||
|
s = strconv.FormatInt(int64(arg), 10)
|
||||||
|
case int64:
|
||||||
|
s = strconv.FormatInt(int64(arg), 10)
|
||||||
|
case float32:
|
||||||
|
s = strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||||
|
case float64:
|
||||||
|
s = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||||
|
case []byte:
|
||||||
|
s = `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||||
|
default:
|
||||||
|
panic("Unable to encode type: " + reflect.TypeOf(arg).String())
|
||||||
|
}
|
||||||
|
binary.Write(buf, binary.BigEndian, int32(len(s)))
|
||||||
|
buf.WriteString(s)
|
||||||
|
}
|
||||||
|
// for _, pd := range ps.ParameterOids {
|
||||||
|
// transcoder := valueTranscoders[pd]
|
||||||
|
// if transcoder == nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
binary.Write(buf, binary.BigEndian, int16(0))
|
||||||
|
|
||||||
|
err = c.txMsg('B', buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// execute
|
||||||
|
buf = c.getBuf()
|
||||||
|
buf.WriteString("")
|
||||||
|
buf.WriteByte(0)
|
||||||
|
binary.Write(buf, binary.BigEndian, int32(0))
|
||||||
|
|
||||||
|
err = c.txMsg('E', buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// sync
|
||||||
|
err = c.txMsg('S', c.getBuf())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) {
|
func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) {
|
||||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,11 +418,13 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
|
||||||
var t byte
|
var t byte
|
||||||
var r *MessageReader
|
var r *MessageReader
|
||||||
if t, r, err = c.rxMsg(); err == nil {
|
if t, r, err = c.rxMsg(); err == nil {
|
||||||
|
// fmt.Printf("Execute received: %c\n", t)
|
||||||
switch t {
|
switch t {
|
||||||
case readyForQuery:
|
case readyForQuery:
|
||||||
return
|
return
|
||||||
case rowDescription:
|
case rowDescription:
|
||||||
case dataRow:
|
case dataRow:
|
||||||
|
case bindComplete:
|
||||||
case commandComplete:
|
case commandComplete:
|
||||||
commandTag = r.ReadString()
|
commandTag = r.ReadString()
|
||||||
default:
|
default:
|
||||||
|
@ -378,6 +563,15 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) rxParameterDescription(r *MessageReader) (parameters []oid) {
|
||||||
|
parameterCount := r.ReadInt16()
|
||||||
|
parameters = make([]oid, 0, parameterCount)
|
||||||
|
for i := int16(0); i < parameterCount; i++ {
|
||||||
|
parameters = append(parameters, r.ReadOid())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
|
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
|
||||||
fieldCount := len(r.fields)
|
fieldCount := len(r.fields)
|
||||||
|
|
||||||
|
|
|
@ -285,3 +285,47 @@ func TestSelectValues(t *testing.T) {
|
||||||
t.Error("Multiple columns should have returned UnexpectedColumnCountError")
|
t.Error("Multiple columns should have returned UnexpectedColumnCountError")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrepare(t *testing.T) {
|
||||||
|
conn, err := Connect(ConnectionParameters{Socket: "/private/tmp/.s.PGSQL.5432", User: "pgx_none", Database: "pgx_test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("Unable to establish connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
testTranscode := func(sql string, value interface{}) {
|
||||||
|
if err = conn.Prepare("testTranscode", sql); err != nil {
|
||||||
|
t.Errorf("Unable to prepare statement: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := conn.Deallocate("testTranscode")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Deallocate failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var result interface{}
|
||||||
|
result, err = conn.SelectValue("testTranscode", value)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%v while running %v", err, "testTranscode")
|
||||||
|
} else {
|
||||||
|
if result != value {
|
||||||
|
t.Errorf("Expected: %#v Received: %#v", value, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test parameter encoding and decoding for simple supported data types
|
||||||
|
testTranscode("select $1::varchar", "foo")
|
||||||
|
testTranscode("select $1::text", "foo")
|
||||||
|
testTranscode("select $1::int2", int16(1))
|
||||||
|
testTranscode("select $1::int4", int32(1))
|
||||||
|
testTranscode("select $1::int8", int64(1))
|
||||||
|
testTranscode("select $1::float4", float32(1.23))
|
||||||
|
testTranscode("select $1::float8", float64(1.23))
|
||||||
|
|
||||||
|
// case []byte:
|
||||||
|
// s = `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||||
|
|
||||||
|
}
|
||||||
|
|
|
@ -20,13 +20,20 @@ func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowR
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *DataRowReader) ReadValue() interface{} {
|
func (r *DataRowReader) ReadValue() interface{} {
|
||||||
dataType := r.fields[r.currentFieldIdx].DataType
|
fieldDescription := r.fields[r.currentFieldIdx]
|
||||||
r.currentFieldIdx++
|
r.currentFieldIdx++
|
||||||
|
|
||||||
size := r.mr.ReadInt32()
|
size := r.mr.ReadInt32()
|
||||||
if size > -1 {
|
if size > -1 {
|
||||||
if vt, present := valueTranscoders[dataType]; present {
|
if vt, present := valueTranscoders[fieldDescription.DataType]; present {
|
||||||
return vt.FromText(r.mr, size)
|
switch fieldDescription.FormatCode {
|
||||||
|
case 0:
|
||||||
|
return vt.DecodeText(r.mr, size)
|
||||||
|
case 1:
|
||||||
|
return vt.DecodeBinary(r.mr, size)
|
||||||
|
default:
|
||||||
|
panic("Unknown format")
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return r.mr.ReadByteString(size)
|
return r.mr.ReadByteString(size)
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,9 @@ const (
|
||||||
commandComplete = 'C'
|
commandComplete = 'C'
|
||||||
errorResponse = 'E'
|
errorResponse = 'E'
|
||||||
noticeResponse = 'N'
|
noticeResponse = 'N'
|
||||||
|
parseComplete = '1'
|
||||||
|
parameterDescription = 't'
|
||||||
|
bindComplete = '2'
|
||||||
)
|
)
|
||||||
|
|
||||||
type startupMessage struct {
|
type startupMessage struct {
|
||||||
|
|
|
@ -15,6 +15,11 @@ func (c *Connection) QuoteString(input string) (output string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connection) QuoteIdentifier(input string) (output string) {
|
||||||
|
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) {
|
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||||
replacer := func(match string) (replacement string) {
|
replacer := func(match string) (replacement string) {
|
||||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||||
|
|
|
@ -2,14 +2,15 @@ package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type valueTranscoder struct {
|
type valueTranscoder struct {
|
||||||
FromText func(*MessageReader, int32) interface{}
|
DecodeText func(*MessageReader, int32) interface{}
|
||||||
// FromBinary func(*MessageReader, int32) interface{}
|
DecodeBinary func(*MessageReader, int32) interface{}
|
||||||
// ToText func(interface{}) string
|
EncodeTo func(io.Writer, interface{})
|
||||||
// ToBinary func(interface{}) []byte
|
EncodeFormat int16
|
||||||
}
|
}
|
||||||
|
|
||||||
var valueTranscoders map[oid]*valueTranscoder
|
var valueTranscoders map[oid]*valueTranscoder
|
||||||
|
@ -18,22 +19,22 @@ func init() {
|
||||||
valueTranscoders = make(map[oid]*valueTranscoder)
|
valueTranscoders = make(map[oid]*valueTranscoder)
|
||||||
|
|
||||||
// bool
|
// bool
|
||||||
valueTranscoders[oid(16)] = &valueTranscoder{FromText: decodeBoolFromText}
|
valueTranscoders[oid(16)] = &valueTranscoder{DecodeText: decodeBoolFromText}
|
||||||
|
|
||||||
// int8
|
// int8
|
||||||
valueTranscoders[oid(20)] = &valueTranscoder{FromText: decodeInt8FromText}
|
valueTranscoders[oid(20)] = &valueTranscoder{DecodeText: decodeInt8FromText}
|
||||||
|
|
||||||
// int2
|
// int2
|
||||||
valueTranscoders[oid(21)] = &valueTranscoder{FromText: decodeInt2FromText}
|
valueTranscoders[oid(21)] = &valueTranscoder{DecodeText: decodeInt2FromText}
|
||||||
|
|
||||||
// int4
|
// int4
|
||||||
valueTranscoders[oid(23)] = &valueTranscoder{FromText: decodeInt4FromText}
|
valueTranscoders[oid(23)] = &valueTranscoder{DecodeText: decodeInt4FromText}
|
||||||
|
|
||||||
// float4
|
// float4
|
||||||
valueTranscoders[oid(700)] = &valueTranscoder{FromText: decodeFloat4FromText}
|
valueTranscoders[oid(700)] = &valueTranscoder{DecodeText: decodeFloat4FromText}
|
||||||
|
|
||||||
// float8
|
// float8
|
||||||
valueTranscoders[oid(701)] = &valueTranscoder{FromText: decodeFloat8FromText}
|
valueTranscoders[oid(701)] = &valueTranscoder{DecodeText: decodeFloat8FromText}
|
||||||
}
|
}
|
||||||
|
|
||||||
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
|
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
|
||||||
|
|
Loading…
Reference in New Issue