mirror of https://github.com/jackc/pgx.git
Add simple protocol suuport with (Query|Exec)Ex
parent
54d9cbc743
commit
7b1f461ec3
66
conn.go
66
conn.go
|
@ -1021,7 +1021,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
|
|||
// Exec executes sql. sql can be either a prepared statement name or an SQL string.
|
||||
// arguments should be referenced positionally from the sql string as $1, $2, etc.
|
||||
func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
return c.ExecContext(context.Background(), sql, arguments...)
|
||||
return c.ExecEx(context.Background(), sql, nil, arguments...)
|
||||
}
|
||||
|
||||
// Processes messages that are not exclusive to one context such as
|
||||
|
@ -1364,24 +1364,16 @@ func (c *Conn) Ping() error {
|
|||
}
|
||||
|
||||
func (c *Conn) PingContext(ctx context.Context) error {
|
||||
_, err := c.ExecContext(ctx, ";")
|
||||
_, err := c.ExecEx(ctx, ";", nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (c *Conn) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.lock(); err != nil {
|
||||
return commandTag, err
|
||||
}
|
||||
|
@ -1406,8 +1398,56 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa
|
|||
}
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
if options != nil && options.SimpleProtocol {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
||||
}
|
||||
} else {
|
||||
if len(arguments) > 0 {
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
ps, err = c.PrepareExContext(ctx, "", sql, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
err = c.sendPreparedQuery(ps, arguments...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
} else {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
err = c.termContext(err)
|
||||
}()
|
||||
|
||||
if err = c.sendQuery(sql, arguments...); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var softErr error
|
||||
|
|
12
conn_pool.go
12
conn_pool.go
|
@ -360,14 +360,14 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman
|
|||
return c.Exec(sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
func (p *ConnPool) ExecEx(ctx context.Context, sql string, options *QueryExOptions, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
var c *Conn
|
||||
if c, err = p.Acquire(); err != nil {
|
||||
return
|
||||
}
|
||||
defer p.Release(c)
|
||||
|
||||
return c.ExecContext(ctx, sql, arguments...)
|
||||
return c.ExecEx(ctx, sql, options, arguments...)
|
||||
}
|
||||
|
||||
// Query acquires a connection and delegates the call to that connection. When
|
||||
|
@ -390,14 +390,14 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) {
|
|||
return rows, nil
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
|
||||
func (p *ConnPool) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (*Rows, error) {
|
||||
c, err := p.Acquire()
|
||||
if err != nil {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
return &Rows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.QueryContext(ctx, sql, args...)
|
||||
rows, err := c.QueryEx(ctx, sql, options, args...)
|
||||
if err != nil {
|
||||
p.Release(c)
|
||||
return rows, err
|
||||
|
@ -416,8 +416,8 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row {
|
|||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||
rows, _ := p.QueryContext(ctx, sql, args...)
|
||||
func (p *ConnPool) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
|
||||
rows, _ := p.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
|
|
80
conn_test.go
80
conn_test.go
|
@ -1033,7 +1033,7 @@ func TestExecFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
func TestExecExContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1042,16 +1042,16 @@ func TestExecContextWithoutCancelation(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(id integer primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
func TestExecExContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1060,18 +1060,18 @@ func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
|
||||
if _, err := conn.ExecEx(ctx, "selct;", nil); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
|
||||
rows, _ := conn.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
|
||||
t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||
func TestExecExContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1083,7 +1083,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
|
||||
_, err := conn.ExecEx(ctx, "select pg_sleep(60)", nil)
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Expected context.Canceled err, got %v", err)
|
||||
}
|
||||
|
@ -1091,6 +1091,70 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecExExtendedProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
nil,
|
||||
"bar",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecExSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.ExecEx(
|
||||
ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -0,0 +1,236 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
if argIdx >= len(args) {
|
||||
return "", fmt.Errorf("insufficient arguments")
|
||||
}
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
case time.Time:
|
||||
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", fmt.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
default:
|
||||
return "", fmt.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
if !used {
|
||||
return "", fmt.Errorf("unused argument: %d", i)
|
||||
}
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '$':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if '0' <= nextRune && nextRune <= '9' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
func placeholderState(l *sqlLexer) stateFn {
|
||||
num := 0
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
num *= 10
|
||||
num += int(r - '0')
|
||||
} else {
|
||||
l.parts = append(l.parts, num)
|
||||
l.pos -= width
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return query.Sanitize(args...)
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
successTests := []struct {
|
||||
sql string
|
||||
expected sanitize.Query
|
||||
}{
|
||||
{
|
||||
sql: "select 42",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
},
|
||||
{
|
||||
sql: "select $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'quoted $42', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "doubled quoted $42", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'foo''bar', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "foo""bar", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select '''', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select """", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
|
||||
},
|
||||
{
|
||||
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
|
||||
},
|
||||
{
|
||||
sql: `select E'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select e'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
query, err := sanitize.NewQuery(tt.sql)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
}
|
||||
|
||||
if len(query.Parts) == len(tt.expected.Parts) {
|
||||
for j := range query.Parts {
|
||||
if query.Parts[j] != tt.expected.Parts[j] {
|
||||
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []interface{}{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
actual, err := tt.query.Sanitize(tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.expected != actual {
|
||||
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := tt.query.Sanitize(tt.args...)
|
||||
if err == nil || err.Error() != tt.expected {
|
||||
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,10 +8,23 @@ import (
|
|||
)
|
||||
|
||||
func TestCidTranscode(t *testing.T) {
|
||||
testSuccessfulTranscode(t, "cid", []interface{}{
|
||||
pgTypeName := "cid"
|
||||
values := []interface{}{
|
||||
pgtype.Cid{Uint: 42, Status: pgtype.Present},
|
||||
pgtype.Cid{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
eqFunc := func(a, b interface{}) bool {
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
|
||||
// No direct conversion from int to cid, convert through text
|
||||
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
|
||||
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCidSet(t *testing.T) {
|
||||
|
|
|
@ -145,7 +145,7 @@ func (dst *Json) Scan(src interface{}) error {
|
|||
func (src Json) Value() (driver.Value, error) {
|
||||
switch src.Status {
|
||||
case Present:
|
||||
return src.Bytes, nil
|
||||
return string(src.Bytes), nil
|
||||
case Null:
|
||||
return nil, nil
|
||||
default:
|
||||
|
|
|
@ -121,13 +121,13 @@ func (src *Numeric) AssignTo(dst interface{}) error {
|
|||
case Present:
|
||||
switch v := dst.(type) {
|
||||
case *float32:
|
||||
f, err := strconv.ParseFloat(src.Int.String(), 64)
|
||||
f, err := src.toFloat64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return float64AssignTo(f, src.Status, dst)
|
||||
case *float64:
|
||||
f, err := strconv.ParseFloat(src.Int.String(), 64)
|
||||
f, err := src.toFloat64()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -283,6 +283,23 @@ func (dst *Numeric) toBigInt() (*big.Int, error) {
|
|||
return num, nil
|
||||
}
|
||||
|
||||
func (src *Numeric) toFloat64() (float64, error) {
|
||||
f, err := strconv.ParseFloat(src.Int.String(), 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if src.Exp > 0 {
|
||||
for i := 0; i < int(src.Exp); i++ {
|
||||
f *= 10
|
||||
}
|
||||
} else if src.Exp < 0 {
|
||||
for i := 0; i > int(src.Exp); i-- {
|
||||
f /= 10
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (dst *Numeric) DecodeText(ci *ConnInfo, src []byte) error {
|
||||
if src == nil {
|
||||
*dst = Numeric{Status: Null}
|
||||
|
|
|
@ -247,9 +247,12 @@ func TestNumericAssignTo(t *testing.T) {
|
|||
}{
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f32, expected: float32(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &f64, expected: float64(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f32, expected: float32(4.2)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: -1, Status: pgtype.Present}, dst: &f64, expected: float64(4.2)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i16, expected: int16(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i32, expected: int32(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i64, expected: int64(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Exp: 3, Status: pgtype.Present}, dst: &i64, expected: int64(42000)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &i, expected: int(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui8, expected: uint8(42)},
|
||||
{src: &pgtype.Numeric{Int: big.NewInt(42), Status: pgtype.Present}, dst: &ui16, expected: uint16(42)},
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package pgtype_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -125,6 +126,7 @@ func testSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
|
|||
|
||||
func testSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
|
@ -175,6 +177,35 @@ func testPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
|
|||
}
|
||||
}
|
||||
|
||||
func testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectPgx(t)
|
||||
defer mustClose(t, conn)
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
fmt.Sprintf("select ($1)::%s", pgTypeName),
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
v,
|
||||
).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("Simple protocol %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := mustConnectDatabaseSQL(t, driverName)
|
||||
defer mustClose(t, conn)
|
||||
|
|
|
@ -8,10 +8,23 @@ import (
|
|||
)
|
||||
|
||||
func TestXidTranscode(t *testing.T) {
|
||||
testSuccessfulTranscode(t, "xid", []interface{}{
|
||||
pgTypeName := "xid"
|
||||
values := []interface{}{
|
||||
pgtype.Xid{Uint: 42, Status: pgtype.Present},
|
||||
pgtype.Xid{Status: pgtype.Null},
|
||||
})
|
||||
}
|
||||
eqFunc := func(a, b interface{}) bool {
|
||||
return reflect.DeepEqual(a, b)
|
||||
}
|
||||
|
||||
testPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
|
||||
// No direct conversion from int to xid, convert through text
|
||||
testPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
|
||||
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestXidSet(t *testing.T) {
|
||||
|
|
65
query.go
65
query.go
|
@ -7,6 +7,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -123,6 +124,17 @@ func (rows *Rows) Next() bool {
|
|||
}
|
||||
|
||||
switch t {
|
||||
case rowDescription:
|
||||
rows.fields = rows.conn.rxRowDescription(r)
|
||||
for i := range rows.fields {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOid(rows.fields[i].DataType); ok {
|
||||
rows.fields[i].DataTypeName = dt.Name
|
||||
rows.fields[i].FormatCode = TextFormatCode
|
||||
} else {
|
||||
rows.Fatal(fmt.Errorf("unknown oid: %d", rows.fields[i].DataType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
case dataRow:
|
||||
fieldCount := r.readInt16()
|
||||
if int(fieldCount) != len(rows.fields) {
|
||||
|
@ -341,7 +353,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) {
|
|||
// be returned in an error state. So it is allowed to ignore the error returned
|
||||
// from Query and handle it in *Rows.
|
||||
func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) {
|
||||
return c.QueryContext(context.Background(), sql, args...)
|
||||
return c.QueryEx(context.Background(), sql, nil, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) getRows(sql string, args []interface{}) *Rows {
|
||||
|
@ -368,7 +380,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
|||
return (*Row)(rows)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) {
|
||||
type QueryExOptions struct {
|
||||
SimpleProtocol bool
|
||||
}
|
||||
|
||||
func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) (rows *Rows, err error) {
|
||||
err = c.waitForPreviousCancelQuery(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -384,6 +400,22 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
|
|||
}
|
||||
rows.unlockConn = true
|
||||
|
||||
if options != nil && options.SimpleProtocol {
|
||||
err = c.initContext(ctx)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
err = c.sanitizeAndSendSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.Fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
var err error
|
||||
|
@ -411,7 +443,32 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}
|
|||
return rows, err
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row {
|
||||
rows, _ := c.QueryContext(ctx, sql, args...)
|
||||
func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) {
|
||||
if c.RuntimeParams["standard_conforming_strings"] != "on" {
|
||||
return errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
}
|
||||
|
||||
if c.RuntimeParams["client_encoding"] != "UTF8" {
|
||||
return errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
||||
}
|
||||
|
||||
valueArgs := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
sql, err = sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.sendSimpleQuery(sql)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRowEx(ctx context.Context, sql string, options *QueryExOptions, args ...interface{}) *Row {
|
||||
rows, _ := c.QueryEx(ctx, sql, options, args...)
|
||||
return (*Row)(rows)
|
||||
}
|
||||
|
|
510
query_test.go
510
query_test.go
|
@ -797,275 +797,6 @@ func TestQueryRowNoResults(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt16Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int16
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int16
|
||||
}{
|
||||
{"select $1::int2[]", []int16{1, 2, 3, 4, 5}},
|
||||
{"select $1::int2[]", []int16{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int2[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
if err != nil && !(strings.Contains(err.Error(), "Cannot decode null") || strings.Contains(err.Error(), "cannot assign")) {
|
||||
t.Errorf(`Expected null to cause error "Cannot decode null..." but it was %v`, err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt32Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int32
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int32
|
||||
}{
|
||||
{"select $1::int4[]", []int32{1, 2, 3, 4, 5}},
|
||||
{"select $1::int4[]", []int32{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int4[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreInt64Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []int64
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []int64
|
||||
}{
|
||||
{"select $1::int8[]", []int64{1, 2, 3, 4, 5}},
|
||||
{"select $1::int8[]", []int64{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1, 2, 3, 4, 5, null}'::int8[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreFloat32Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []float32
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []float32
|
||||
}{
|
||||
{"select $1::float4[]", []float32{1.5, 2.0, 3.5}},
|
||||
{"select $1::float4[]", []float32{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float4[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreFloat64Slice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []float64
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []float64
|
||||
}{
|
||||
{"select $1::float8[]", []float64{1.5, 2.0, 3.5}},
|
||||
{"select $1::float8[]", []float64{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{1.5, 2.0, 3.5, null}'::float8[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowCoreStringSlice(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var actual []string
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
expected []string
|
||||
}{
|
||||
{"select $1::text[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
|
||||
{"select $1::text[]", []string{}},
|
||||
{"select $1::varchar[]", []string{"Adam", "Eve", "UTF-8 Characters Å Æ Ë Ͽ"}},
|
||||
{"select $1::varchar[]", []string{}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
err := conn.QueryRow(tt.sql, tt.expected).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Errorf("%d. Unexpected failure: %v", i, err)
|
||||
}
|
||||
|
||||
if len(actual) != len(tt.expected) {
|
||||
t.Errorf("%d. Expected %v, got %v", i, tt.expected, actual)
|
||||
}
|
||||
|
||||
for j := 0; j < len(actual); j++ {
|
||||
if actual[j] != tt.expected[j] {
|
||||
t.Errorf("%d. Expected actual[%d] to be %v, got %v", i, j, tt.expected[j], actual[j])
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
// Check that Scan errors when an array with a null is scanned into a core slice type
|
||||
err := conn.QueryRow("select '{Adam,Eve,NULL}'::text[];").Scan(&actual)
|
||||
if err == nil {
|
||||
t.Error("Expected null to cause error when scanned into slice, but it didn't")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestReadingValueAfterEmptyArray(t *testing.T) {
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
@ -1236,7 +967,7 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextSuccess(t *testing.T) {
|
||||
func TestQueryExContextSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1245,7 +976,7 @@ func TestQueryContextSuccess(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select 42::integer")
|
||||
rows, err := conn.QueryEx(ctx, "select 42::integer", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1273,7 +1004,7 @@ func TestQueryContextSuccess(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
|
||||
func TestQueryExContextErrorWhileReceivingRows(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1282,7 +1013,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
|
|||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n")
|
||||
rows, err := conn.QueryEx(ctx, "select 10/(10-n) from generate_series(1, 100) n", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1310,7 +1041,7 @@ func TestQueryContextErrorWhileReceivingRows(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryContextCancelationCancelsQuery(t *testing.T) {
|
||||
func TestQueryExContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1322,7 +1053,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
|
|||
cancelFunc()
|
||||
}()
|
||||
|
||||
rows, err := conn.QueryContext(ctx, "select pg_sleep(5)")
|
||||
rows, err := conn.QueryEx(ctx, "select pg_sleep(5)", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1338,7 +1069,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextSuccess(t *testing.T) {
|
||||
func TestQueryRowExContextSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1348,7 +1079,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
|
|||
defer cancelFunc()
|
||||
|
||||
var result int
|
||||
err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result)
|
||||
err := conn.QueryRowEx(ctx, "select 42::integer", nil).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -1359,7 +1090,7 @@ func TestQueryRowContextSuccess(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
||||
func TestQueryRowExContextErrorWhileReceivingRow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1369,7 +1100,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
|||
defer cancelFunc()
|
||||
|
||||
var result int
|
||||
err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result)
|
||||
err := conn.QueryRowEx(ctx, "select 10/0", nil).Scan(&result)
|
||||
if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" {
|
||||
t.Fatalf("Expected division by zero error, but got %v", err)
|
||||
}
|
||||
|
@ -1377,7 +1108,7 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
|
||||
func TestQueryRowExContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
|
@ -1390,10 +1121,227 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
|
|||
}()
|
||||
|
||||
var result []byte
|
||||
err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result)
|
||||
err := conn.QueryRowEx(ctx, "select pg_sleep(5)", nil).Scan(&result)
|
||||
if err != context.Canceled {
|
||||
t.Fatalf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Test all supported low-level types
|
||||
|
||||
{
|
||||
expected := int64(42)
|
||||
var actual int64
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::int8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := float64(1.23)
|
||||
var actual float64
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::float8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := true
|
||||
var actual bool
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
|
||||
var actual []byte
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::bytea",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := "test"
|
||||
var actual string
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
|
||||
{
|
||||
expected := pgtype.Line{A: 1, B: 2, C: 1.5, Status: pgtype.Present}
|
||||
actual := expected
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::line",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
&expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
|
||||
{
|
||||
expectedInt64 := int64(234423)
|
||||
expectedFloat64 := float64(-0.2312)
|
||||
expectedBool := true
|
||||
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
|
||||
expectedString := "test"
|
||||
var actualInt64 int64
|
||||
var actualFloat64 float64
|
||||
var actualBool bool
|
||||
var actualBytes []byte
|
||||
var actualString string
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expectedInt64 != actualInt64 {
|
||||
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
|
||||
}
|
||||
if expectedFloat64 != actualFloat64 {
|
||||
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
|
||||
}
|
||||
if expectedBool != actualBool {
|
||||
t.Errorf("expected %v got %v", expectedBool, actualBool)
|
||||
}
|
||||
if bytes.Compare(expectedBytes, actualBytes) != 0 {
|
||||
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
|
||||
}
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
|
||||
{
|
||||
expected := "foo';drop table users;"
|
||||
var actual string
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"test",
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when client_encoding not UTF8, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set standard_conforming_strings to off")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRowEx(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
`\'; drop table users; --`,
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
|
@ -268,7 +268,7 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr
|
|||
|
||||
args := namedValueToInterface(argsV)
|
||||
|
||||
rows, err := c.conn.QueryContext(ctx, name, args...)
|
||||
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return nil, err
|
||||
|
|
|
@ -49,8 +49,8 @@ func TestStressConnPool(t *testing.T) {
|
|||
{"listenAndPoolUnlistens", listenAndPoolUnlistens},
|
||||
{"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }},
|
||||
{"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate},
|
||||
{"canceledQueryContext", canceledQueryContext},
|
||||
{"canceledExecContext", canceledExecContext},
|
||||
{"canceledQueryExContext", canceledQueryExContext},
|
||||
{"canceledExecExContext", canceledExecExContext},
|
||||
}
|
||||
|
||||
actionCount := 1000
|
||||
|
@ -317,14 +317,14 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error {
|
|||
return tx.Commit()
|
||||
}
|
||||
|
||||
func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
func canceledQueryExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
rows, err := pool.QueryContext(ctx, "select pg_sleep(2)")
|
||||
rows, err := pool.QueryEx(ctx, "select pg_sleep(2)", nil)
|
||||
if err == context.Canceled {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
|
@ -342,14 +342,14 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func canceledExecContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
func canceledExecExContext(pool *pgx.ConnPool, actionNum int) error {
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := pool.ExecContext(ctx, "select pg_sleep(2)")
|
||||
_, err := pool.ExecEx(ctx, "select pg_sleep(2)", nil)
|
||||
if err != context.Canceled {
|
||||
return fmt.Errorf("Expected context.Canceled error, got %v", err)
|
||||
}
|
||||
|
|
76
values.go
76
values.go
|
@ -4,7 +4,9 @@ import (
|
|||
"bytes"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
@ -22,6 +24,80 @@ func (e SerializationError) Error() string {
|
|||
return string(e)
|
||||
}
|
||||
|
||||
func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, error) {
|
||||
if arg == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch arg := arg.(type) {
|
||||
case driver.Valuer:
|
||||
return arg.Value()
|
||||
case pgtype.TextEncoder:
|
||||
buf := &bytes.Buffer{}
|
||||
null, err := arg.EncodeText(ci, buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if null {
|
||||
return nil, nil
|
||||
}
|
||||
return buf.String(), nil
|
||||
case int64:
|
||||
return arg, nil
|
||||
case float64:
|
||||
return arg, nil
|
||||
case bool:
|
||||
return arg, nil
|
||||
case time.Time:
|
||||
return arg, nil
|
||||
case string:
|
||||
return arg, nil
|
||||
case []byte:
|
||||
return arg, nil
|
||||
case int8:
|
||||
return int64(arg), nil
|
||||
case int16:
|
||||
return int64(arg), nil
|
||||
case int32:
|
||||
return int64(arg), nil
|
||||
case int:
|
||||
return int64(arg), nil
|
||||
case uint8:
|
||||
return int64(arg), nil
|
||||
case uint16:
|
||||
return int64(arg), nil
|
||||
case uint32:
|
||||
return int64(arg), nil
|
||||
case uint64:
|
||||
if arg > math.MaxInt64 {
|
||||
return nil, fmt.Errorf("arg too big for int64: %v", arg)
|
||||
}
|
||||
return int64(arg), nil
|
||||
case uint:
|
||||
if arg > math.MaxInt64 {
|
||||
return nil, fmt.Errorf("arg too big for int64: %v", arg)
|
||||
}
|
||||
return int64(arg), nil
|
||||
case float32:
|
||||
return float64(arg), nil
|
||||
}
|
||||
|
||||
refVal := reflect.ValueOf(arg)
|
||||
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
if refVal.IsNil() {
|
||||
return nil, nil
|
||||
}
|
||||
arg = refVal.Elem().Interface()
|
||||
return convertSimpleArgument(ci, arg)
|
||||
}
|
||||
|
||||
if strippedArg, ok := stripNamedType(&refVal); ok {
|
||||
return convertSimpleArgument(ci, strippedArg)
|
||||
}
|
||||
return nil, SerializationError(fmt.Sprintf("Cannot encode %T in simple protocol - %T must implement driver.Valuer, pgtype.TextEncoder, or be a native type", arg, arg))
|
||||
}
|
||||
|
||||
func encodePreparedStatementArgument(wbuf *WriteBuf, oid pgtype.Oid, arg interface{}) error {
|
||||
if arg == nil {
|
||||
wbuf.WriteInt32(-1)
|
||||
|
|
Loading…
Reference in New Issue