mirror of https://github.com/jackc/pgx.git
Dirty, but somewhat working prepared statements and extended protocol
parent
1042f095ee
commit
5073a3b9e0
200
connection.go
200
connection.go
|
@ -9,6 +9,8 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type ConnectionParameters struct {
|
||||
|
@ -28,6 +30,13 @@ type Connection struct {
|
|||
runtimeParams map[string]string // parameters that have been reported by the server
|
||||
parameters ConnectionParameters // parameters used when establishing this connection
|
||||
txStatus byte
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
}
|
||||
|
||||
type PreparedStatement struct {
|
||||
Name string
|
||||
FieldDescriptions []FieldDescription
|
||||
ParameterOids []oid
|
||||
}
|
||||
|
||||
type NotSingleRowError struct {
|
||||
|
@ -71,6 +80,7 @@ func Connect(parameters ConnectionParameters) (c *Connection, err error) {
|
|||
|
||||
c.buf = bytes.NewBuffer(make([]byte, sharedBufferSize))
|
||||
c.runtimeParams = make(map[string]string)
|
||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||
|
||||
msg := newStartupMessage()
|
||||
msg.options["user"] = c.parameters.User
|
||||
|
@ -108,12 +118,19 @@ func (c *Connection) Close() (err error) {
|
|||
}
|
||||
|
||||
func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error, arguments ...interface{}) (err error) {
|
||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
||||
var fields []FieldDescription
|
||||
|
||||
if ps, present := c.preparedStatements[sql]; present {
|
||||
fields = ps.FieldDescriptions
|
||||
err = c.sendPreparedQuery(ps, arguments...)
|
||||
} else {
|
||||
err = c.sendSimpleQuery(sql, arguments...)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var callbackError error
|
||||
var fields []FieldDescription
|
||||
|
||||
for {
|
||||
var t byte
|
||||
|
@ -132,6 +149,7 @@ func (c *Connection) SelectFunc(sql string, onDataRow func(*DataRowReader) error
|
|||
callbackError = onDataRow(newDataRowReader(r, fields))
|
||||
}
|
||||
case commandComplete:
|
||||
case bindComplete:
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return
|
||||
|
@ -207,6 +225,101 @@ func (c *Connection) SelectValues(sql string, arguments ...interface{}) (values
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) Prepare(name, sql string) (err error) {
|
||||
// parse
|
||||
buf := c.getBuf()
|
||||
_, err = buf.WriteString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buf.WriteString(sql)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = binary.Write(buf, binary.BigEndian, int16(0))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.txMsg('P', buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// describe
|
||||
buf = c.getBuf()
|
||||
err = buf.WriteByte('S')
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = buf.WriteString(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = buf.WriteByte(0)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = c.txMsg('D', buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// sync
|
||||
err = c.txMsg('S', c.getBuf())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ps := PreparedStatement{Name: name}
|
||||
|
||||
for {
|
||||
var t byte
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
switch t {
|
||||
case parseComplete:
|
||||
case parameterDescription:
|
||||
ps.ParameterOids = c.rxParameterDescription(r)
|
||||
case rowDescription:
|
||||
ps.FieldDescriptions = c.rxRowDescription(r)
|
||||
case readyForQuery:
|
||||
c.preparedStatements[name] = &ps
|
||||
return
|
||||
default:
|
||||
if err = c.processContextFreeMsg(t, r); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) Deallocate(name string) (err error) {
|
||||
delete(c.preparedStatements, name)
|
||||
_, err = c.Execute("deallocate " + c.QuoteIdentifier(name))
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) sendQuery(sql string, arguments ...interface{}) (err error) {
|
||||
if ps, present := c.preparedStatements[sql]; present {
|
||||
return c.sendPreparedQuery(ps, arguments...)
|
||||
} else {
|
||||
return c.sendSimpleQuery(sql, arguments...)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err error) {
|
||||
if len(arguments) > 0 {
|
||||
sql = c.SanitizeSql(sql, arguments...)
|
||||
|
@ -226,8 +339,78 @@ func (c *Connection) sendSimpleQuery(sql string, arguments ...interface{}) (err
|
|||
return c.txMsg('Q', buf)
|
||||
}
|
||||
|
||||
func (c *Connection) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) {
|
||||
if len(ps.ParameterOids) != len(arguments) {
|
||||
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
|
||||
}
|
||||
|
||||
// bind
|
||||
buf := c.getBuf()
|
||||
buf.WriteString("")
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(ps.Name)
|
||||
buf.WriteByte(0)
|
||||
binary.Write(buf, binary.BigEndian, int16(0))
|
||||
binary.Write(buf, binary.BigEndian, int16(len(arguments)))
|
||||
for _, iArg := range arguments {
|
||||
var s string
|
||||
switch arg := iArg.(type) {
|
||||
case string:
|
||||
s = arg
|
||||
case int16:
|
||||
s = strconv.FormatInt(int64(arg), 10)
|
||||
case int32:
|
||||
s = strconv.FormatInt(int64(arg), 10)
|
||||
case int64:
|
||||
s = strconv.FormatInt(int64(arg), 10)
|
||||
case float32:
|
||||
s = strconv.FormatFloat(float64(arg), 'f', -1, 32)
|
||||
case float64:
|
||||
s = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case []byte:
|
||||
s = `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
default:
|
||||
panic("Unable to encode type: " + reflect.TypeOf(arg).String())
|
||||
}
|
||||
binary.Write(buf, binary.BigEndian, int32(len(s)))
|
||||
buf.WriteString(s)
|
||||
}
|
||||
// for _, pd := range ps.ParameterOids {
|
||||
// transcoder := valueTranscoders[pd]
|
||||
// if transcoder == nil {
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
binary.Write(buf, binary.BigEndian, int16(0))
|
||||
|
||||
err = c.txMsg('B', buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// execute
|
||||
buf = c.getBuf()
|
||||
buf.WriteString("")
|
||||
buf.WriteByte(0)
|
||||
binary.Write(buf, binary.BigEndian, int32(0))
|
||||
|
||||
err = c.txMsg('E', buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sync
|
||||
err = c.txMsg('S', c.getBuf())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
}
|
||||
|
||||
func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag string, err error) {
|
||||
if err = c.sendSimpleQuery(sql, arguments...); err != nil {
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -235,11 +418,13 @@ func (c *Connection) Execute(sql string, arguments ...interface{}) (commandTag s
|
|||
var t byte
|
||||
var r *MessageReader
|
||||
if t, r, err = c.rxMsg(); err == nil {
|
||||
// fmt.Printf("Execute received: %c\n", t)
|
||||
switch t {
|
||||
case readyForQuery:
|
||||
return
|
||||
case rowDescription:
|
||||
case dataRow:
|
||||
case bindComplete:
|
||||
case commandComplete:
|
||||
commandTag = r.ReadString()
|
||||
default:
|
||||
|
@ -378,6 +563,15 @@ func (c *Connection) rxRowDescription(r *MessageReader) (fields []FieldDescripti
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxParameterDescription(r *MessageReader) (parameters []oid) {
|
||||
parameterCount := r.ReadInt16()
|
||||
parameters = make([]oid, 0, parameterCount)
|
||||
for i := int16(0); i < parameterCount; i++ {
|
||||
parameters = append(parameters, r.ReadOid())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) rxDataRow(r *DataRowReader) (row map[string]interface{}) {
|
||||
fieldCount := len(r.fields)
|
||||
|
||||
|
|
|
@ -285,3 +285,47 @@ func TestSelectValues(t *testing.T) {
|
|||
t.Error("Multiple columns should have returned UnexpectedColumnCountError")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
conn, err := Connect(ConnectionParameters{Socket: "/private/tmp/.s.PGSQL.5432", User: "pgx_none", Database: "pgx_test"})
|
||||
if err != nil {
|
||||
t.Fatal("Unable to establish connection")
|
||||
}
|
||||
|
||||
testTranscode := func(sql string, value interface{}) {
|
||||
if err = conn.Prepare("testTranscode", sql); err != nil {
|
||||
t.Errorf("Unable to prepare statement: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err := conn.Deallocate("testTranscode")
|
||||
if err != nil {
|
||||
t.Errorf("Deallocate failed: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var result interface{}
|
||||
result, err = conn.SelectValue("testTranscode", value)
|
||||
if err != nil {
|
||||
t.Errorf("%v while running %v", err, "testTranscode")
|
||||
} else {
|
||||
if result != value {
|
||||
t.Errorf("Expected: %#v Received: %#v", value, result)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Test parameter encoding and decoding for simple supported data types
|
||||
testTranscode("select $1::varchar", "foo")
|
||||
testTranscode("select $1::text", "foo")
|
||||
testTranscode("select $1::int2", int16(1))
|
||||
testTranscode("select $1::int4", int32(1))
|
||||
testTranscode("select $1::int8", int64(1))
|
||||
testTranscode("select $1::float4", float32(1.23))
|
||||
testTranscode("select $1::float8", float64(1.23))
|
||||
|
||||
// case []byte:
|
||||
// s = `E'\\x` + hex.EncodeToString(arg) + `'`
|
||||
|
||||
}
|
||||
|
|
|
@ -20,13 +20,20 @@ func newDataRowReader(mr *MessageReader, fields []FieldDescription) (r *DataRowR
|
|||
}
|
||||
|
||||
func (r *DataRowReader) ReadValue() interface{} {
|
||||
dataType := r.fields[r.currentFieldIdx].DataType
|
||||
fieldDescription := r.fields[r.currentFieldIdx]
|
||||
r.currentFieldIdx++
|
||||
|
||||
size := r.mr.ReadInt32()
|
||||
if size > -1 {
|
||||
if vt, present := valueTranscoders[dataType]; present {
|
||||
return vt.FromText(r.mr, size)
|
||||
if vt, present := valueTranscoders[fieldDescription.DataType]; present {
|
||||
switch fieldDescription.FormatCode {
|
||||
case 0:
|
||||
return vt.DecodeText(r.mr, size)
|
||||
case 1:
|
||||
return vt.DecodeBinary(r.mr, size)
|
||||
default:
|
||||
panic("Unknown format")
|
||||
}
|
||||
} else {
|
||||
return r.mr.ReadByteString(size)
|
||||
}
|
||||
|
|
|
@ -17,6 +17,9 @@ const (
|
|||
commandComplete = 'C'
|
||||
errorResponse = 'E'
|
||||
noticeResponse = 'N'
|
||||
parseComplete = '1'
|
||||
parameterDescription = 't'
|
||||
bindComplete = '2'
|
||||
)
|
||||
|
||||
type startupMessage struct {
|
||||
|
|
|
@ -15,6 +15,11 @@ func (c *Connection) QuoteString(input string) (output string) {
|
|||
return
|
||||
}
|
||||
|
||||
func (c *Connection) QuoteIdentifier(input string) (output string) {
|
||||
output = `"` + strings.Replace(input, `"`, `""`, -1) + `"`
|
||||
return
|
||||
}
|
||||
|
||||
func (c *Connection) SanitizeSql(sql string, args ...interface{}) (output string) {
|
||||
replacer := func(match string) (replacement string) {
|
||||
n, _ := strconv.ParseInt(match[1:], 10, 0)
|
||||
|
|
|
@ -2,14 +2,15 @@ package pgx
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type valueTranscoder struct {
|
||||
FromText func(*MessageReader, int32) interface{}
|
||||
// FromBinary func(*MessageReader, int32) interface{}
|
||||
// ToText func(interface{}) string
|
||||
// ToBinary func(interface{}) []byte
|
||||
DecodeText func(*MessageReader, int32) interface{}
|
||||
DecodeBinary func(*MessageReader, int32) interface{}
|
||||
EncodeTo func(io.Writer, interface{})
|
||||
EncodeFormat int16
|
||||
}
|
||||
|
||||
var valueTranscoders map[oid]*valueTranscoder
|
||||
|
@ -18,22 +19,22 @@ func init() {
|
|||
valueTranscoders = make(map[oid]*valueTranscoder)
|
||||
|
||||
// bool
|
||||
valueTranscoders[oid(16)] = &valueTranscoder{FromText: decodeBoolFromText}
|
||||
valueTranscoders[oid(16)] = &valueTranscoder{DecodeText: decodeBoolFromText}
|
||||
|
||||
// int8
|
||||
valueTranscoders[oid(20)] = &valueTranscoder{FromText: decodeInt8FromText}
|
||||
valueTranscoders[oid(20)] = &valueTranscoder{DecodeText: decodeInt8FromText}
|
||||
|
||||
// int2
|
||||
valueTranscoders[oid(21)] = &valueTranscoder{FromText: decodeInt2FromText}
|
||||
valueTranscoders[oid(21)] = &valueTranscoder{DecodeText: decodeInt2FromText}
|
||||
|
||||
// int4
|
||||
valueTranscoders[oid(23)] = &valueTranscoder{FromText: decodeInt4FromText}
|
||||
valueTranscoders[oid(23)] = &valueTranscoder{DecodeText: decodeInt4FromText}
|
||||
|
||||
// float4
|
||||
valueTranscoders[oid(700)] = &valueTranscoder{FromText: decodeFloat4FromText}
|
||||
valueTranscoders[oid(700)] = &valueTranscoder{DecodeText: decodeFloat4FromText}
|
||||
|
||||
// float8
|
||||
valueTranscoders[oid(701)] = &valueTranscoder{FromText: decodeFloat8FromText}
|
||||
valueTranscoders[oid(701)] = &valueTranscoder{DecodeText: decodeFloat8FromText}
|
||||
}
|
||||
|
||||
func decodeBoolFromText(mr *MessageReader, size int32) interface{} {
|
||||
|
|
Loading…
Reference in New Issue