Add simple protocol suuport with (Query|Exec)Ex

batch-wip
Jack Christensen 2017-04-10 08:58:51 -05:00
parent 54d9cbc743
commit 7b1f461ec3
16 changed files with 999 additions and 326 deletions

66
conn.go
View File

@ -1021,7 +1021,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
// arguments should be referenced positionally from the sql string as $1, $2, etc.
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
return c.ExecContext(context.Background(), sql, arguments...)
return c.ExecEx(context.Background(), sql, nil, arguments...)
}
// Processes messages that are not exclusive to one context such as
@ -1364,24 +1364,16 @@ func (c *Conn) Ping() error {
}
func (c *Conn) PingContext(ctx context.Context) error {
_, err := c.ExecContext(ctx, ";")
_, err := c.ExecEx(ctx, ";", nil)
return err
}
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return "", err
}
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
if err = c.lock(); err != nil {
return commandTag, err
}
@ -1406,8 +1398,56 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa
}
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
if options != nil && options.SimpleProtocol {
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
if err != nil {
return "", err
}
} else {
if len(arguments) > 0 {
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
ps, err = c.PrepareExContext(ctx, "", sql, nil)
if err != nil {
return "", err
}
}
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
err = c.sendPreparedQuery(ps, arguments...)
if err != nil {
return "", err
}
} else {
err = c.initContext(ctx)
if err != nil {
return "", err
}
defer func() {
err = c.termContext(err)
}()
if err = c.sendQuery(sql, arguments...); err != nil {
return
}
}
}
var softErr error

View File

@ -360,14 +360,14 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
return c.Exec(sql, arguments...)
}
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
var c *Conn
if c, err = p.Acquire(); err != nil {
return
}
defer p.Release(c)
return c.ExecContext(ctx, sql, arguments...)
return c.ExecEx(ctx, sql, options, arguments...)
}
// Query acquires a connection and delegates the call to that connection. When
@ -390,14 +390,14 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
return rows, nil
}
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
c, err := p.Acquire()
if err != nil {
// Because checking for errors can be deferred to the *Rows, build one with the error
return &Rows{closed: true, err: err}, err
}
rows, err := c.QueryContext(ctx, sql, args...)
rows, err := c.QueryEx(ctx, sql, options, args...)
if err != nil {
p.Release(c)
return rows, err
@ -416,8 +416,8 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := p.QueryContext(ctx, sql, args...)
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
rows, _ := p.QueryEx(ctx, sql, options, args...)
return (*Row)(rows)
}

View File

@ -1033,7 +1033,7 @@ func TestExecFailure(t *testing.T) {
}
}
func TestExecContextWithoutCancelation(t *testing.T) {
func TestExecExContextWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1042,16 +1042,16 @@ func TestExecContextWithoutCancelation(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil)
if err != nil {
t.Fatal(err)
}
if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
}
func TestExecContextFailureWithoutCancelation(t *testing.T) {
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1060,18 +1060,18 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
t.Fatal("Expected SQL syntax error")
}
rows, _ := conn.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
}
}
func TestExecContextCancelationCancelsQuery(t *testing.T) {
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1083,7 +1083,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) {
cancelFunc()
}()
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil)
if err != context.Canceled {
t.Fatalf("Expected context.Canceled err, got %v", err)
}
@ -1091,6 +1091,70 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) {
ensureConnValid(t, conn)
}
func TestExecExExtendedProtocol(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
if err != nil {
t.Fatal(err)
}
if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
commandTag, err = conn.ExecEx(
ctx,
"insert into foo(name) values($1);",
nil,
"bar",
)
if err != nil {
t.Fatal(err)
}
if commandTag != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
ensureConnValid(t, conn)
}
func TestExecExSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
if err != nil {
t.Fatal(err)
}
if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
commandTag, err = conn.ExecEx(
ctx,
"insert into foo(name) values($1);",
&pgx.QueryExOptions{SimpleProtocol: true},
"bar'; drop table foo;--",
)
if err != nil {
t.Fatal(err)
}
if commandTag != "INSERT 0 1" {
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
}
}
func TestPrepare(t *testing.T) {
t.Parallel()

View File

@ -0,0 +1,236 @@
package sanitize
import (
"bytes"
"encoding/hex"
"fmt"
"strconv"
"strings"
"time"
"unicode/utf8"
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part interface{}
type Query struct {
Parts []Part
}
func (q *Query) Sanitize(args ...interface{}) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int64:
str = strconv.FormatInt(arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
case []byte:
str = QuoteBytes(arg)
case string:
str = QuoteString(arg)
case time.Time:
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", fmt.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}

View File

@ -0,0 +1,175 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
{
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
},
{
sql: `select E'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
},
{
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{int64(42)},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{float64(1.23)},
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{true},
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
for i, tt := range successfulTests {
actual, err := tt.query.Sanitize(tt.args...)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
if tt.expected != actual {
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
}
}
errorTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []interface{}{int64(42)},
expected: `insufficient arguments`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
args: []interface{}{int64(42)},
expected: `unused argument: 0`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{42},
expected: `invalid arg type: int`,
},
}
for i, tt := range errorTests {
_, err := tt.query.Sanitize(tt.args...)
if err == nil || err.Error() != tt.expected {
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
}
}
}

View File

@ -8,10 +8,23 @@ import (
)
func TestCidTranscode(t *testing.T) {
testSuccessfulTranscode(t, "cid", []interface{}{
pgTypeName := "cid"
values := []interface{}{
pgtype.Cid{Uint: 42, Status: pgtype.Present},
pgtype.Cid{Status: pgtype.Null},
})
}
eqFunc := func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
}
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
// No direct conversion from int to cid, convert through text
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
}
func TestCidSet(t *testing.T) {

View File

@ -145,7 +145,7 @@ func (dst *Json) Scan(src interface{}) error {
func (src Json) Value() (driver.Value, error) {
switch src.Status {
case Present:
return src.Bytes, nil
return string(src.Bytes), nil
case Null:
return nil, nil
default:

View File

@ -121,13 +121,13 @@ func (src *Numeric) AssignTo(dst interface{}) error {
case Present:
switch v := dst.(type) {
case *float32:
f, err := strconv.ParseFloat(src.Int.String(), 64)
f, err := src.toFloat64()
if err != nil {
return err
}
return float64AssignTo(f, src.Status, dst)
case *float64:
f, err := strconv.ParseFloat(src.Int.String(), 64)
f, err := src.toFloat64()
if err != nil {
return err
}
@ -283,6 +283,23 @@ func (dst *Numeric) toBigInt() (*big.Int, error) {
return num, nil
}
func (src *Numeric) toFloat64() (float64, error) {
f, err := strconv.ParseFloat(src.Int.String(), 64)
if err != nil {
return 0, err
}
if src.Exp > 0 {
for i := 0; i < int(src.Exp); i++ {
f *= 10
}
} else if src.Exp < 0 {
for i := 0; i > int(src.Exp); i-- {
f /= 10
}
}
return f, nil
}
func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
if src == nil {
*dst = Numeric{Status: Null}

View File

@ -247,9 +247,12 @@ func TestNumericAssignTo(t *testing.T) {
}{
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)},
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)},

View File

@ -1,6 +1,7 @@
package pgtype_test
import (
"context"
"database/sql"
"fmt"
"io"
@ -125,6 +126,7 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
@ -175,6 +177,35 @@ func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
}
}
func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectPgx(t)
defer mustClose(t, conn)
for i, v := range values {
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := conn.QueryRowEx(
context.Background(),
fmt.Sprintf("select ($1)::%s", pgTypeName),
&pgx.QueryExOptions{SimpleProtocol: true},
v,
).Scan(result.Interface())
if err != nil {
t.Errorf("Simple protocol %d: %v", i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
}
}
}
func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := mustConnectDatabaseSQL(t, driverName)
defer mustClose(t, conn)

View File

@ -8,10 +8,23 @@ import (
)
func TestXidTranscode(t *testing.T) {
testSuccessfulTranscode(t, "xid", []interface{}{
pgTypeName := "xid"
values := []interface{}{
pgtype.Xid{Uint: 42, Status: pgtype.Present},
pgtype.Xid{Status: pgtype.Null},
})
}
eqFunc := func(a, b interface{}) bool {
return reflect.DeepEqual(a, b)
}
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
// No direct conversion from int to xid, convert through text
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
}
func TestXidSet(t *testing.T) {

View File

@ -7,6 +7,7 @@ import (
"fmt"
"time"
"github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgtype"
)
@ -123,6 +124,17 @@ func (rows *Rows) Next() bool {
}
switch t {
case rowDescription:
rows.fields = rows.conn.rxRowDescription(r)
for i := range rows.fields {
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
rows.fields[i].DataTypeName = dt.Name
rows.fields[i].FormatCode = TextFormatCode
} else {
rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType))
return false
}
}
case dataRow:
fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) {
@ -341,7 +353,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
// be returned in an error state. So it is allowed to ignore the error returned
// from Query and handle it in *Rows.
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
return c.QueryContext(context.Background(), sql, args...)
return c.QueryEx(context.Background(), sql, nil, args...)
}
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
@ -368,7 +380,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
return (*Row)(rows)
}
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
type QueryExOptions struct {
SimpleProtocol bool
}
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
err = c.waitForPreviousCancelQuery(ctx)
if err != nil {
return nil, err
@ -384,6 +400,22 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
}
rows.unlockConn = true
if options != nil && options.SimpleProtocol {
err = c.initContext(ctx)
if err != nil {
rows.Fatal(err)
return rows, err
}
err = c.sanitizeAndSendSimpleQuery(sql, args...)
if err != nil {
rows.Fatal(err)
return rows, err
}
return rows, nil
}
ps, ok := c.preparedStatements[sql]
if !ok {
var err error
@ -411,7 +443,32 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
return rows, err
}
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
rows, _ := c.QueryContext(ctx, sql, args...)
func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
if c.RuntimeParams["standard_conforming_strings"] != "on" {
return errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if c.RuntimeParams["client_encoding"] != "UTF8" {
return errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
valueArgs := make([]interface{}, len(args))
for i, a := range args {
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
if err != nil {
return err
}
}
sql, err = sanitize.SanitizeSQL(sql, valueArgs...)
if err != nil {
return err
}
return c.sendSimpleQuery(sql)
}
func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
rows, _ := c.QueryEx(ctx, sql, options, args...)
return (*Row)(rows)
}

View File

@ -797,275 +797,6 @@ func TestQueryRowNoResults(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt16Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int16
tests := []struct {
sql string
expected []int16
}{
{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
{"select $1::int2[]", []int16{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) {
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt32Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int32
tests := []struct {
sql string
expected []int32
}{
{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
{"select $1::int4[]", []int32{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreInt64Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []int64
tests := []struct {
sql string
expected []int64
}{
{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
{"select $1::int8[]", []int64{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreFloat32Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []float32
tests := []struct {
sql string
expected []float32
}{
{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
{"select $1::float4[]", []float32{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreFloat64Slice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []float64
tests := []struct {
sql string
expected []float64
}{
{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
{"select $1::float8[]", []float64{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestQueryRowCoreStringSlice(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
var actual []string
tests := []struct {
sql string
expected []string
}{
{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
{"select $1::text[]", []string{}},
{"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
{"select $1::varchar[]", []string{}},
}
for i, tt := range tests {
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
if err != nil {
t.Errorf("%d. Unexpected failure: %v", i, err)
}
if len(actual) != len(tt.expected) {
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
}
for j := 0; j < len(actual); j++ {
if actual[j] != tt.expected[j] {
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
}
}
ensureConnValid(t, conn)
}
// Check that Scan errors when an array with a null is scanned into a core slice type
err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
if err == nil {
t.Error("Expected null to cause error when scanned into slice, but it didn't")
}
ensureConnValid(t, conn)
}
func TestReadingValueAfterEmptyArray(t *testing.T) {
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
@ -1236,7 +967,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextSuccess(t *testing.T) {
func TestQueryExContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1245,7 +976,7 @@ func TestQueryContextSuccess(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 42::integer")
rows, err := conn.QueryEx(ctx, "select 42::integer", nil)
if err != nil {
t.Fatal(err)
}
@ -1273,7 +1004,7 @@ func TestQueryContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1282,7 +1013,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil)
if err != nil {
t.Fatal(err)
}
@ -1310,7 +1041,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1322,7 +1053,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
cancelFunc()
}()
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil)
if err != nil {
t.Fatal(err)
}
@ -1338,7 +1069,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextSuccess(t *testing.T) {
func TestQueryRowExContextSuccess(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1348,7 +1079,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result)
if err != nil {
t.Fatal(err)
}
@ -1359,7 +1090,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1369,7 +1100,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
defer cancelFunc()
var result int
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result)
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
t.Fatalf("Expected division by zero error, but got %v", err)
}
@ -1377,7 +1108,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
ensureConnValid(t, conn)
}
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
@ -1390,10 +1121,227 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
}()
var result []byte
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result)
if err != context.Canceled {
t.Fatalf("Expected context.Canceled error, got %v", err)
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
// Test all supported low-level types
{
expected := int64(42)
var actual int64
err := conn.QueryRowEx(
context.Background(),
"select $1::int8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := float64(1.23)
var actual float64
err := conn.QueryRowEx(
context.Background(),
"select $1::float8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := true
var actual bool
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
var actual []byte
err := conn.QueryRowEx(
context.Background(),
"select $1::bytea",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if bytes.Compare(actual, expected) != 0 {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := "test"
var actual string
err := conn.QueryRowEx(
context.Background(),
"select $1::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test high-level type
{
expected := pgtype.Line{A: 1, B: 2, C: 1.5, Status: pgtype.Present}
actual := expected
err := conn.QueryRowEx(
context.Background(),
"select $1::line",
&pgx.QueryExOptions{SimpleProtocol: true},
&expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test multiple args in single query
{
expectedInt64 := int64(234423)
expectedFloat64 := float64(-0.2312)
expectedBool := true
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
expectedString := "test"
var actualInt64 int64
var actualFloat64 float64
var actualBool bool
var actualBytes []byte
var actualString string
err := conn.QueryRowEx(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
if err != nil {
t.Error(err)
}
if expectedInt64 != actualInt64 {
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
}
if expectedFloat64 != actualFloat64 {
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
}
if expectedBool != actualBool {
t.Errorf("expected %v got %v", expectedBool, actualBool)
}
if bytes.Compare(expectedBytes, actualBytes) != 0 {
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
}
if expectedString != actualString {
t.Errorf("expected %v got %v", expectedString, actualString)
}
}
// Test dangerous cases
{
expected := "foo';drop table users;"
var actual string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
var expected string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
"test",
).Scan(&expected)
if err == nil {
t.Error("expected error when client_encoding not UTF8, but no error occurred")
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
mustExec(t, conn, "set standard_conforming_strings to off")
var expected string
err := conn.QueryRowEx(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
`\'; drop table users; --`,
).Scan(&expected)
if err == nil {
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
}
ensureConnValid(t, conn)
}

View File

@ -268,7 +268,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryContext(ctx, name, args...)
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
if err != nil {
fmt.Println(err)
return nil, err

View File

@ -49,8 +49,8 @@ func TestStressConnPool(t *testing.T) {
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
{"canceledQueryContext", canceledQueryContext},
{"canceledExecContext", canceledExecContext},
{"canceledQueryExContext", canceledQueryExContext},
{"canceledExecExContext", canceledExecExContext},
}
actionCount := 1000
@ -317,14 +317,14 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
return tx.Commit()
}
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
if err == context.Canceled {
return nil
} else if err != nil {
@ -342,14 +342,14 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
return nil
}
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
cancelFunc()
}()
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
_, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil)
if err != context.Canceled {
return fmt.Errorf("Expected context.Canceled error, got %v", err)
}

View File

@ -4,7 +4,9 @@ import (
"bytes"
"database/sql/driver"
"fmt"
"math"
"reflect"
"time"
"github.com/jackc/pgx/pgtype"
)
@ -22,6 +24,80 @@ func (e SerializationError) Error() string {
return string(e)
}
func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
if arg == nil {
return nil, nil
}
switch arg := arg.(type) {
case driver.Valuer:
return arg.Value()
case pgtype.TextEncoder:
buf := &bytes.Buffer{}
null, err := arg.EncodeText(ci, buf)
if err != nil {
return nil, err
}
if null {
return nil, nil
}
return buf.String(), nil
case int64:
return arg, nil
case float64:
return arg, nil
case bool:
return arg, nil
case time.Time:
return arg, nil
case string:
return arg, nil
case []byte:
return arg, nil
case int8:
return int64(arg), nil
case int16:
return int64(arg), nil
case int32:
return int64(arg), nil
case int:
return int64(arg), nil
case uint8:
return int64(arg), nil
case uint16:
return int64(arg), nil
case uint32:
return int64(arg), nil
case uint64:
if arg > math.MaxInt64 {
return nil, fmt.Errorf("arg too big for int64: %v", arg)
}
return int64(arg), nil
case uint:
if arg > math.MaxInt64 {
return nil, fmt.Errorf("arg too big for int64: %v", arg)
}
return int64(arg), nil
case float32:
return float64(arg), nil
}
refVal := reflect.ValueOf(arg)
if refVal.Kind() == reflect.Ptr {
if refVal.IsNil() {
return nil, nil
}
arg = refVal.Elem().Interface()
return convertSimpleArgument(ci, arg)
}
if strippedArg, ok := stripNamedType(&refVal); ok {
return convertSimpleArgument(ci, strippedArg)
}
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
}
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
if arg == nil {
wbuf.WriteInt32(-1)