Dirty, but somewhat working prepared statements and extended protocol

pgx-vs-pq
Jack Christensen 2013-07-01 15:41:20 -05:00
parent 1042f095ee
commit 5073a3b9e0
6 changed files with 285 additions and 31 deletions

View File

@ -9,6 +9,8 @@ import (
"fmt"
"io"
"net"
"reflect"
"strconv"
)
type ConnectionParameters struct {
@ -28,6 +30,13 @@ type Connection struct {
runtimeParams map[string]string // parameters that have been reported by the server
parameters ConnectionParameters // parameters used when establishing this connection
txStatus byte
preparedStatements map[string]*PreparedStatement
}
type PreparedStatement struct {
Name string
FieldDescriptions []FieldDescription
ParameterOids []oid
}
type NotSingleRowError struct {
@ -71,6 +80,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize))
c.runtimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement)
msg := newStartupMessage()
msg.options["user"] = c.parameters.User
@ -108,12 +118,19 @@ func (c *Connection) Close() (err error) {
}
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) {
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
var fields []FieldDescription
if ps, present := c.preparedStatements[sql]; present {
fields = ps.FieldDescriptions
err = c.sendPreparedQuery(ps, arguments...)
} else {
err = c.sendSimpleQuery(sql, arguments...)
}
if err != nil {
return
}
var callbackError error
var fields []FieldDescription
for {
var t byte
@ -132,6 +149,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
callbackError = onDataRow(newDataRowReader(r, fields))
}
case commandComplete:
case bindComplete:
default:
if err = c.processContextFreeMsg(t, r); err != nil {
return
@ -207,6 +225,101 @@ func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values
return
}
func (c *Connection) Prepare(name, sql string) (err error) {
// parse
buf := c.getBuf()
_, err = buf.WriteString(name)
if err != nil {
return
}
err = buf.WriteByte(0)
if err != nil {
return
}
_, err = buf.WriteString(sql)
if err != nil {
return
}
err = buf.WriteByte(0)
if err != nil {
return
}
err = binary.Write(buf, binary.BigEndian, int16(0))
if err != nil {
return
}
err = c.txMsg('P', buf)
if err != nil {
return
}
// describe
buf = c.getBuf()
err = buf.WriteByte('S')
if err != nil {
return
}
_, err = buf.WriteString(name)
if err != nil {
return
}
err = buf.WriteByte(0)
if err != nil {
return
}
err = c.txMsg('D', buf)
if err != nil {
return
}
// sync
err = c.txMsg('S', c.getBuf())
if err != nil {
return err
}
ps := PreparedStatement{Name: name}
for {
var t byte
var r *MessageReader
if t, r, err = c.rxMsg(); err == nil {
switch t {
case parseComplete:
case parameterDescription:
ps.ParameterOids = c.rxParameterDescription(r)
case rowDescription:
ps.FieldDescriptions = c.rxRowDescription(r)
case readyForQuery:
c.preparedStatements[name] = &ps
return
default:
if err = c.processContextFreeMsg(t, r); err != nil {
return
}
}
} else {
return
}
}
}
func (c *Connection) Deallocate(name string) (err error) {
delete(c.preparedStatements, name)
_, err = c.Execute("deallocate " + c.QuoteIdentifier(name))
return
}
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
if ps, present := c.preparedStatements[sql]; present {
return c.sendPreparedQuery(ps, arguments...)
} else {
return c.sendSimpleQuery(sql, arguments...)
}
}
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
if len(arguments) > 0 {
sql = c.SanitizeSql(sql, arguments...)
@ -226,8 +339,78 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err
return c.txMsg('Q', buf)
}
func (c *Connection) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
if len(ps.ParameterOids) != len(arguments) {
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
}
// bind
buf := c.getBuf()
buf.WriteString("")
buf.WriteByte(0)
buf.WriteString(ps.Name)
buf.WriteByte(0)
binary.Write(buf, binary.BigEndian, int16(0))
binary.Write(buf, binary.BigEndian, int16(len(arguments)))
for _, iArg := range arguments {
var s string
switch arg := iArg.(type) {
case string:
s = arg
case int16:
s = strconv.FormatInt(int64(arg), 10)
case int32:
s = strconv.FormatInt(int64(arg), 10)
case int64:
s = strconv.FormatInt(int64(arg), 10)
case float32:
s = strconv.FormatFloat(float64(arg), 'f', -1, 32)
case float64:
s = strconv.FormatFloat(arg, 'f', -1, 64)
case []byte:
s = `E'\\x` + hex.EncodeToString(arg) + `'`
default:
panic("Unable to encode type: " + reflect.TypeOf(arg).String())
}
binary.Write(buf, binary.BigEndian, int32(len(s)))
buf.WriteString(s)
}
// for _, pd := range ps.ParameterOids {
// transcoder := valueTranscoders[pd]
// if transcoder == nil {
// return
// }
// }
binary.Write(buf, binary.BigEndian, int16(0))
err = c.txMsg('B', buf)
if err != nil {
return err
}
// execute
buf = c.getBuf()
buf.WriteString("")
buf.WriteByte(0)
binary.Write(buf, binary.BigEndian, int32(0))
err = c.txMsg('E', buf)
if err != nil {
return err
}
// sync
err = c.txMsg('S', c.getBuf())
if err != nil {
return err
}
return
}
func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) {
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
@ -235,11 +418,13 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
var t byte
var r *MessageReader
if t, r, err = c.rxMsg(); err == nil {
// fmt.Printf("Execute received: %c\n", t)
switch t {
case readyForQuery:
return
case rowDescription:
case dataRow:
case bindComplete:
case commandComplete:
commandTag = r.ReadString()
default:
@ -378,6 +563,15 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
return
}
func (c *Connection) rxParameterDescription(r *MessageReader) (parameters []oid) {
parameterCount := r.ReadInt16()
parameters = make([]oid, 0, parameterCount)
for i := int16(0); i < parameterCount; i++ {
parameters = append(parameters, r.ReadOid())
}
return
}
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
fieldCount := len(r.fields)

View File

@ -285,3 +285,47 @@ func TestSelectValues(t *testing.T) {
t.Error("Multiple columns should have returned UnexpectedColumnCountError")
}
}
func TestPrepare(t *testing.T) {
conn, err := Connect(ConnectionParameters{Socket: "/private/tmp/.s.PGSQL.5432", User: "pgx_none", Database: "pgx_test"})
if err != nil {
t.Fatal("Unable to establish connection")
}
testTranscode := func(sql string, value interface{}) {
if err = conn.Prepare("testTranscode", sql); err != nil {
t.Errorf("Unable to prepare statement: %v", err)
return
}
defer func() {
err := conn.Deallocate("testTranscode")
if err != nil {
t.Errorf("Deallocate failed: %v", err)
}
}()
var result interface{}
result, err = conn.SelectValue("testTranscode", value)
if err != nil {
t.Errorf("%v while running %v", err, "testTranscode")
} else {
if result != value {
t.Errorf("Expected: %#v Received: %#v", value, result)
}
}
}
// Test parameter encoding and decoding for simple supported data types
testTranscode("select $1::varchar", "foo")
testTranscode("select $1::text", "foo")
testTranscode("select $1::int2", int16(1))
testTranscode("select $1::int4", int32(1))
testTranscode("select $1::int8", int64(1))
testTranscode("select $1::float4", float32(1.23))
testTranscode("select $1::float8", float64(1.23))
// case []byte:
// s = `E'\\x` + hex.EncodeToString(arg) + `'`
}

View File

@ -20,13 +20,20 @@ func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowR
}
func (r *DataRowReader) ReadValue() interface{} {
dataType := r.fields[r.currentFieldIdx].DataType
fieldDescription := r.fields[r.currentFieldIdx]
r.currentFieldIdx++
size := r.mr.ReadInt32()
if size > -1 {
if vt, present := valueTranscoders[dataType]; present {
return vt.FromText(r.mr, size)
if vt, present := valueTranscoders[fieldDescription.DataType]; present {
switch fieldDescription.FormatCode {
case 0:
return vt.DecodeText(r.mr, size)
case 1:
return vt.DecodeBinary(r.mr, size)
default:
panic("Unknown format")
}
} else {
return r.mr.ReadByteString(size)
}

View File

@ -17,6 +17,9 @@ const (
commandComplete = 'C'
errorResponse = 'E'
noticeResponse = 'N'
parseComplete = '1'
parameterDescription = 't'
bindComplete = '2'
)
type startupMessage struct {

View File

@ -15,6 +15,11 @@ func (c *Connection) QuoteString(input string) (output string) {
return
}
func (c *Connection) QuoteIdentifier(input string) (output string) {
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
return
}
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) {
replacer := func(match string) (replacement string) {
n, _ := strconv.ParseInt(match[1:], 10, 0)

View File

@ -2,14 +2,15 @@ package pgx
import (
"fmt"
"io"
"strconv"
)
type valueTranscoder struct {
FromText func(*MessageReader, int32) interface{}
// FromBinary func(*MessageReader, int32) interface{}
// ToText func(interface{}) string
// ToBinary func(interface{}) []byte
DecodeText func(*MessageReader, int32) interface{}
DecodeBinary func(*MessageReader, int32) interface{}
EncodeTo func(io.Writer, interface{})
EncodeFormat int16
}
var valueTranscoders map[oid]*valueTranscoder
@ -18,22 +19,22 @@ func init() {
valueTranscoders = make(map[oid]*valueTranscoder)
// bool
valueTranscoders[oid(16)] = &valueTranscoder{FromText: decodeBoolFromText}
valueTranscoders[oid(16)] = &valueTranscoder{DecodeText: decodeBoolFromText}
// int8
valueTranscoders[oid(20)] = &valueTranscoder{FromText: decodeInt8FromText}
valueTranscoders[oid(20)] = &valueTranscoder{DecodeText: decodeInt8FromText}
// int2
valueTranscoders[oid(21)] = &valueTranscoder{FromText: decodeInt2FromText}
valueTranscoders[oid(21)] = &valueTranscoder{DecodeText: decodeInt2FromText}
// int4
valueTranscoders[oid(23)] = &valueTranscoder{FromText: decodeInt4FromText}
valueTranscoders[oid(23)] = &valueTranscoder{DecodeText: decodeInt4FromText}
// float4
valueTranscoders[oid(700)] = &valueTranscoder{FromText: decodeFloat4FromText}
valueTranscoders[oid(700)] = &valueTranscoder{DecodeText: decodeFloat4FromText}
// float8
valueTranscoders[oid(701)] = &valueTranscoder{FromText: decodeFloat8FromText}
valueTranscoders[oid(701)] = &valueTranscoder{DecodeText: decodeFloat8FromText}
}
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {