mirror of https://github.com/jackc/pgx.git
Restore simple protocol support
parent
6d23b58b01
commit
29f02807b0
92
conn.go
92
conn.go
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgproto3/v2"
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v4/internal/sanitize"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -24,6 +25,14 @@ type ConnConfig struct {
|
|||
pgconn.Config
|
||||
Logger Logger
|
||||
LogLevel LogLevel
|
||||
|
||||
// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
|
||||
// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
|
||||
// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
|
||||
// and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
|
||||
// used by default. The same functionality can be controlled on a per query basis by setting
|
||||
// QueryExOptions.SimpleProtocol.
|
||||
PreferSimpleProtocol bool
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
|
@ -390,6 +399,36 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
}
|
||||
|
||||
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
|
||||
simpleProtocol := c.config.PreferSimpleProtocol
|
||||
|
||||
optionLoop:
|
||||
for len(arguments) > 0 {
|
||||
switch arg := arguments[0].(type) {
|
||||
case QuerySimpleProtocol:
|
||||
simpleProtocol = bool(arg)
|
||||
arguments = arguments[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
if simpleProtocol {
|
||||
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mrr := c.pgConn.Exec(ctx, sql)
|
||||
if mrr.NextResult() {
|
||||
result := mrr.ResultReader().Read()
|
||||
err = mrr.Close()
|
||||
return result.CommandTag, err
|
||||
} else {
|
||||
err = mrr.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
|
||||
if ps, ok := c.preparedStatements[sql]; ok {
|
||||
|
@ -495,6 +534,9 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
|||
return r
|
||||
}
|
||||
|
||||
// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
|
||||
type QuerySimpleProtocol bool
|
||||
|
||||
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
|
||||
type QueryResultFormats []int16
|
||||
|
||||
|
@ -506,6 +548,7 @@ type QueryResultFormatsByOID map[pgtype.OID]int16
|
|||
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
||||
var resultFormats QueryResultFormats
|
||||
var resultFormatsByOID QueryResultFormatsByOID
|
||||
simpleProtocol := c.config.PreferSimpleProtocol
|
||||
|
||||
optionLoop:
|
||||
for len(args) > 0 {
|
||||
|
@ -516,14 +559,39 @@ optionLoop:
|
|||
case QueryResultFormatsByOID:
|
||||
resultFormatsByOID = arg
|
||||
args = args[1:]
|
||||
case QuerySimpleProtocol:
|
||||
simpleProtocol = bool(arg)
|
||||
args = args[1:]
|
||||
default:
|
||||
break optionLoop
|
||||
}
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
rows := c.getRows(sql, args)
|
||||
|
||||
var err error
|
||||
if simpleProtocol {
|
||||
sql, err = c.sanitizeForSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
mrr := c.pgConn.Exec(ctx, sql)
|
||||
if mrr.NextResult() {
|
||||
rows.resultReader = mrr.ResultReader()
|
||||
rows.multiResultReader = mrr
|
||||
} else {
|
||||
err = mrr.Close()
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
c.eqb.Reset()
|
||||
|
||||
ps, ok := c.preparedStatements[sql]
|
||||
if !ok {
|
||||
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
|
||||
|
@ -550,7 +618,6 @@ optionLoop:
|
|||
}
|
||||
rows.sql = ps.SQL
|
||||
|
||||
var err error
|
||||
args, err = convertDriverValuers(args)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
|
@ -663,3 +730,24 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
|
|||
mrr: mrr,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
|
||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
}
|
||||
|
||||
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
|
||||
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
||||
}
|
||||
|
||||
var err error
|
||||
valueArgs := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
|
56
conn_test.go
56
conn_test.go
|
@ -71,6 +71,31 @@ func TestConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
connConfig.PreferSimpleProtocol = true
|
||||
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// If simple protocol is used we should be able to correctly scan the result
|
||||
// into a pgtype.Text as the integer will have been encoded in text.
|
||||
|
||||
var s pgtype.Text
|
||||
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s.Get() != "42" {
|
||||
t.Fatalf(`expected "42", got %v`, s)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -251,6 +276,37 @@ func TestExecExtendedProtocol(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
|
||||
commandTag, err = conn.Exec(ctx,
|
||||
"insert into foo(name) values($1);",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
"bar'; drop table foo;--",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(commandTag) != "INSERT 0 1" {
|
||||
t.Fatalf("Unexpected results from Exec: %v", commandTag)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -0,0 +1,237 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
errors "golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
if argIdx >= len(args) {
|
||||
return "", errors.Errorf("insufficient arguments")
|
||||
}
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
case time.Time:
|
||||
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", errors.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
default:
|
||||
return "", errors.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
if !used {
|
||||
return "", errors.Errorf("unused argument: %d", i)
|
||||
}
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '$':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if '0' <= nextRune && nextRune <= '9' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
func placeholderState(l *sqlLexer) stateFn {
|
||||
num := 0
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
num *= 10
|
||||
num += int(r - '0')
|
||||
} else {
|
||||
l.parts = append(l.parts, num)
|
||||
l.pos -= width
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return query.Sanitize(args...)
|
||||
}
|
|
@ -0,0 +1,175 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/v4/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
successTests := []struct {
|
||||
sql string
|
||||
expected sanitize.Query
|
||||
}{
|
||||
{
|
||||
sql: "select 42",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
},
|
||||
{
|
||||
sql: "select $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'quoted $42', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "doubled quoted $42", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'foo''bar', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "foo""bar", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select '''', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select """", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
|
||||
},
|
||||
{
|
||||
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
|
||||
},
|
||||
{
|
||||
sql: `select E'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select e'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
query, err := sanitize.NewQuery(tt.sql)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
}
|
||||
|
||||
if len(query.Parts) == len(tt.expected.Parts) {
|
||||
for j := range query.Parts {
|
||||
if query.Parts[j] != tt.expected.Parts[j] {
|
||||
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []interface{}{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
actual, err := tt.query.Sanitize(tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.expected != actual {
|
||||
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := tt.query.Sanitize(tt.args...)
|
||||
if err == nil || err.Error() != tt.expected {
|
||||
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
|
||||
}
|
||||
}
|
||||
}
|
217
query_test.go
217
query_test.go
|
@ -1305,3 +1305,220 @@ func TestRowsFromResultReader(t *testing.T) {
|
|||
t.Error("Wrong values returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Test all supported low-level types
|
||||
|
||||
{
|
||||
expected := int64(42)
|
||||
var actual int64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := float64(1.23)
|
||||
var actual float64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::float8",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := true
|
||||
var actual bool
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
|
||||
var actual []byte
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::bytea",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := "test"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::text",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
|
||||
{
|
||||
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
|
||||
actual := expected
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::circle",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
&expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
|
||||
{
|
||||
expectedInt64 := int64(234423)
|
||||
expectedFloat64 := float64(-0.2312)
|
||||
expectedBool := true
|
||||
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
|
||||
expectedString := "test"
|
||||
var actualInt64 int64
|
||||
var actualFloat64 float64
|
||||
var actualBool bool
|
||||
var actualBytes []byte
|
||||
var actualString string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expectedInt64 != actualInt64 {
|
||||
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
|
||||
}
|
||||
if expectedFloat64 != actualFloat64 {
|
||||
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
|
||||
}
|
||||
if expectedBool != actualBool {
|
||||
t.Errorf("expected %v got %v", expectedBool, actualBool)
|
||||
}
|
||||
if bytes.Compare(expectedBytes, actualBytes) != 0 {
|
||||
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
|
||||
}
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
|
||||
{
|
||||
expected := "foo';drop table users;"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
"test",
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when client_encoding not UTF8, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set standard_conforming_strings to off")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
pgx.QuerySimpleProtocol(true),
|
||||
`\'; drop table users; --`,
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
|
10
rows.go
10
rows.go
|
@ -86,7 +86,8 @@ type connRows struct {
|
|||
args []interface{}
|
||||
closed bool
|
||||
|
||||
resultReader *pgconn.ResultReader
|
||||
resultReader *pgconn.ResultReader
|
||||
multiResultReader *pgconn.MultiResultReader
|
||||
}
|
||||
|
||||
func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
|
||||
|
@ -107,6 +108,13 @@ func (rows *connRows) Close() {
|
|||
}
|
||||
}
|
||||
|
||||
if rows.multiResultReader != nil {
|
||||
closeErr := rows.multiResultReader.Close()
|
||||
if rows.err == nil {
|
||||
rows.err = closeErr
|
||||
}
|
||||
}
|
||||
|
||||
if rows.logger != nil {
|
||||
if rows.err == nil {
|
||||
if rows.logger.shouldLog(LogLevelInfo) {
|
||||
|
|
|
@ -1054,6 +1054,83 @@ func TestRowsColumnTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||
config, err := pgx.ParseConfig("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
|
||||
if err != nil {
|
||||
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
|
||||
}
|
||||
config.PreferSimpleProtocol = true
|
||||
|
||||
db := stdlib.OpenDB(*config)
|
||||
defer closeDB(t, db)
|
||||
|
||||
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rowCount := int64(0)
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
var (
|
||||
s string
|
||||
n int64
|
||||
)
|
||||
|
||||
if err := rows.Scan(&s, &n); err != nil {
|
||||
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
if s != "foo" {
|
||||
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
}
|
||||
|
||||
if n != rowCount {
|
||||
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
}
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rows, err = db.Query("select 1 where false")
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rowCount = int64(0)
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 0 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/409
|
||||
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
||||
db := openDB(t)
|
||||
|
|
Loading…
Reference in New Issue