Remove simple protocol and one round trip query options

It is impossible to guarantee that the a query executed with the simple
protocol will behave the same as with the extended protocol. This is
because the normal pgx path relies on knowing the OID of query
parameters. Without this encoding a value can only be determined by the
value instead of the combination of value and PostgreSQL type. For
example, how should a []int32 be encoded? It might be encoded into a
PostgreSQL int4[] or json.

Removal also simplifies the core query path.

The primary reason for the simple protocol is for servers like PgBouncer
that may not be able to support normal prepared statements. After
further research it appears that issuing a "flush" instead "sync" after
preparing the unnamed statement would allow PgBouncer to work.

The one round trip mode can be better handled with prepared statements.

As a last resort, all original server functionality can still be accessed by
dropping down to PgConn.
pull/483/head
Jack Christensen 2019-04-13 11:39:01 -05:00
parent 5a374c467f
commit c53c9e6eb5
16 changed files with 32 additions and 1021 deletions

11
conn.go
View File

@ -45,17 +45,6 @@ type ConnConfig struct {
Logger Logger
LogLevel LogLevel
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
// PreferSimpleProtocol disables implicit prepared statement usage. By default
// pgx automatically uses the unnamed prepared statement for Query and
// QueryRow. It also uses a prepared statement when Exec has arguments. This
// can improve performance due to being able to use the binary format. It also
// does not rely on client side parameter sanitization. However, it does incur
// two round-trips per query and may be incompatible proxies such as
// PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
// used by default. The same functionality can be controlled on a per query
// basis by setting QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.

View File

@ -78,31 +78,6 @@ func TestConnect(t *testing.T) {
}
}
func TestConnectWithPreferSimpleProtocol(t *testing.T) {
t.Parallel()
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
connConfig.PreferSimpleProtocol = true
conn := mustConnect(t, connConfig)
defer closeConn(t, conn)
// If simple protocol is used we should be able to correctly scan the result
// into a pgtype.Text as the integer will have been encoded in text.
var s pgtype.Text
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
if err != nil {
t.Fatal(err)
}
if s.Get() != "42" {
t.Fatalf(`expected "42", got %v`, s)
}
ensureConnValid(t, conn)
}
func TestExec(t *testing.T) {
t.Parallel()
@ -285,44 +260,6 @@ func TestExecExtendedProtocol(t *testing.T) {
ensureConnValid(t, conn)
}
func TestExecSimpleProtocol(t *testing.T) {
t.Skip("TODO when with simple protocol supported in connection")
// t.Parallel()
// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
// defer closeConn(t, conn)
// ctx, cancelFunc := context.WithCancel(context.Background())
// defer cancelFunc()
// commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
// if err != nil {
// t.Fatal(err)
// }
// if string(commandTag) != "CREATE TABLE" {
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
// }
// if !conn.LastStmtSent() {
// t.Error("Expected LastStmtSent to return true")
// }
// commandTag, err = conn.ExecEx(
// ctx,
// "insert into foo(name) values($1);",
// &pgx.QueryExOptions{SimpleProtocol: true},
// "bar'; drop table foo;--",
// )
// if err != nil {
// t.Fatal(err)
// }
// if string(commandTag) != "INSERT 0 1" {
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
// }
// if !conn.LastStmtSent() {
// t.Error("Expected LastStmtSent to return true")
// }
}
func TestExecExFailureCloseBefore(t *testing.T) {
t.Parallel()

View File

@ -1,237 +0,0 @@
package sanitize
import (
"bytes"
"encoding/hex"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/pkg/errors"
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part interface{}
type Query struct {
Parts []Part
}
func (q *Query) Sanitize(args ...interface{}) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx >= len(args) {
return "", errors.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int64:
str = strconv.FormatInt(arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
case []byte:
str = QuoteBytes(arg)
case string:
str = QuoteString(arg)
case time.Time:
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", errors.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
default:
return "", errors.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", errors.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}

View File

@ -1,175 +0,0 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
{
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
},
{
sql: `select E'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
},
{
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{int64(42)},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{float64(1.23)},
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{true},
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
for i, tt := range successfulTests {
actual, err := tt.query.Sanitize(tt.args...)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
if tt.expected != actual {
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
}
}
errorTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []interface{}{int64(42)},
expected: `insufficient arguments`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
args: []interface{}{int64(42)},
expected: `unused argument: 0`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{42},
expected: `invalid arg type: int`,
},
}
for i, tt := range errorTests {
_, err := tt.query.Sanitize(tt.args...)
if err == nil || err.Error() != tt.expected {
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
}
}
}

View File

@ -20,9 +20,6 @@ func TestCIDTranscode(t *testing.T) {
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
// No direct conversion from int to cid, convert through text
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}

View File

@ -98,7 +98,6 @@ func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}
@ -150,35 +149,6 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
}
}
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := MustConnectPgx(t)
defer MustCloseContext(t, conn)
for i, v := range values {
// Derefence value if it is a pointer
derefV := v
refVal := reflect.ValueOf(v)
if refVal.Kind() == reflect.Ptr {
derefV = refVal.Elem().Interface()
}
result := reflect.New(reflect.TypeOf(derefV))
err := conn.QueryRow(
context.Background(),
fmt.Sprintf("select ($1)::%s", pgTypeName),
&pgx.QueryExOptions{SimpleProtocol: true},
v,
).Scan(result.Interface())
if err != nil {
t.Errorf("Simple protocol %d: %v", i, err)
}
if !eqFunc(result.Elem().Interface(), derefV) {
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
}
}
}
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
conn := MustConnectDatabaseSQL(t, driverName)
defer MustClose(t, conn)

View File

@ -20,9 +20,6 @@ func TestXIDTranscode(t *testing.T) {
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
// No direct conversion from int to xid, convert through text
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
}

View File

@ -37,7 +37,7 @@ func testExec(t *testing.T, db execer) {
}
type queryer interface {
Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
}
func testQuery(t *testing.T, db queryer) {
@ -59,7 +59,7 @@ func testQuery(t *testing.T, db queryer) {
}
type queryRower interface {
QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
}
func testQueryRow(t *testing.T, db queryRower) {

View File

@ -53,12 +53,12 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
return c.Conn().Exec(ctx, sql, arguments...)
}
func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
return c.Conn().Query(ctx, sql, optionsAndArgs...)
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
return c.Conn().Query(ctx, sql, args...)
}
func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
return c.Conn().QueryRow(ctx, sql, optionsAndArgs...)
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
return c.Conn().QueryRow(ctx, sql, args...)
}
func (c *Conn) Begin() (*pgx.Tx, error) {

View File

@ -68,13 +68,13 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (
return c.Exec(ctx, sql, arguments...)
}
func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
c, err := p.Acquire(ctx)
if err != nil {
return errRows{err: err}, err
}
rows, err := c.Query(ctx, sql, optionsAndArgs...)
rows, err := c.Query(ctx, sql, args...)
if err != nil {
c.Release()
return errRows{err: err}, err
@ -83,13 +83,13 @@ func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
return &poolRows{r: rows, c: c}, nil
}
func (p *Pool) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
c, err := p.Acquire(ctx)
if err != nil {
return errRow{err: err}
}
row := c.QueryRow(ctx, sql, optionsAndArgs...)
row := c.QueryRow(ctx, sql, args...)
return &poolRow{r: row, c: c}
}

View File

@ -38,10 +38,10 @@ func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (p
return tx.c.Exec(ctx, sql, arguments...)
}
func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
return tx.c.Query(ctx, sql, optionsAndArgs...)
func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
return tx.c.Query(ctx, sql, args...)
}
func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
return tx.c.QueryRow(ctx, sql, optionsAndArgs...)
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
return tx.c.QueryRow(ctx, sql, args...)
}

View File

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/internal/sanitize"
"github.com/jackc/pgx/pgtype"
)
@ -288,31 +287,12 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
return r
}
type QueryExOptions struct {
// When ParameterOIDs are present and the query is not a prepared statement,
// then ParameterOIDs and ResultFormatCodes will be used to avoid an extra
// network round-trip.
ParameterOIDs []pgtype.OID
ResultFormatCodes []int16
SimpleProtocol bool
}
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
// allowed to ignore the error returned from Query and handle it in Rows.
func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) {
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
c.lastStmtSent = false
// rows = c.getRows(sql, args)
var options *QueryExOptions
args := optionsAndArgs
if len(optionsAndArgs) > 0 {
if o, ok := optionsAndArgs[0].(*QueryExOptions); ok {
options = o
args = optionsAndArgs[1:]
}
}
rows := &connRows{
conn: c,
startTime: time.Now(),
@ -332,27 +312,6 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
// return rows, rows.err
// }
var err error
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
}
c.lastStmtSent = true
rows.multiResultReader = c.pgConn.Exec(ctx, sql)
if rows.multiResultReader.NextResult() {
rows.resultReader = rows.multiResultReader.ResultReader()
} else {
err = rows.multiResultReader.Close()
rows.fatal(err)
return rows, err
}
return rows, nil
}
// if options != nil && len(options.ParameterOIDs) > 0 {
// buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
@ -427,6 +386,7 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
}
rows.sql = ps.SQL
var err error
args, err = convertDriverValuers(args)
if err != nil {
rows.fatal(err)
@ -461,31 +421,10 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
return rows, rows.err
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
var err error
valueArgs := make([]interface{}, len(args))
for i, a := range args {
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
if err != nil {
return "", err
}
}
return sanitize.SanitizeSQL(sql, valueArgs...)
}
// QueryRow is a convenience wrapper over Query. Any error that occurs while
// querying is deferred until calling Scan on the returned Row. That Row will
// error with ErrNoRows if no rows are returned.
func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row {
rows, _ := c.Query(ctx, sql, optionsAndArgs...)
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
rows, _ := c.Query(ctx, sql, args...)
return (*connRow)(rows.(*connRows))
}

View File

@ -908,44 +908,6 @@ func TestQueryRowErrors(t *testing.T) {
}
}
func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
sql := `
with t as (
select 1::int8 as some_int, 'foo'::text as some_text
)
select some_int from t where some_text = $1`
paramOIDs := []pgtype.OID{pgtype.TextArrayOID}
queryArgs := []interface{}{"bar"}
queryOptions := &pgx.QueryExOptions{
ParameterOIDs: paramOIDs,
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
}
optionsAndArgs := append([]interface{}{queryOptions}, queryArgs...)
expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)"
var result int64
err := conn.QueryRow(
context.Background(),
sql,
optionsAndArgs...,
).Scan(&result)
if err == nil {
t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs)
}
if err != nil && !strings.Contains(err.Error(), expectedErr) {
t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)",
expectedErr, err, sql, paramOIDs, queryArgs)
}
ensureConnValid(t, conn)
}
func TestQueryRowNoResults(t *testing.T) {
t.Parallel()
@ -1302,273 +1264,6 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
ensureConnValid(t, conn)
}
func TestConnQueryRowSingleRoundTrip(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
var result int32
err := conn.QueryRow(
context.Background(),
"select $1 + $2",
&pgx.QueryExOptions{
ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
},
1, 2,
).Scan(&result)
if err != nil {
t.Fatal(err)
}
if result != 3 {
t.Fatalf("result => %d, want %d", result, 3)
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
// Test all supported low-level types
{
expected := int64(42)
var actual int64
err := conn.QueryRow(
context.Background(),
"select $1::int8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
{
expected := float64(1.23)
var actual float64
err := conn.QueryRow(
context.Background(),
"select $1::float8",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
{
expected := true
var actual bool
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
{
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
var actual []byte
err := conn.QueryRow(
context.Background(),
"select $1::bytea",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if bytes.Compare(actual, expected) != 0 {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
{
expected := "test"
var actual string
err := conn.QueryRow(
context.Background(),
"select $1::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
// Test high-level type
{
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
actual := expected
err := conn.QueryRow(
context.Background(),
"select $1::circle",
&pgx.QueryExOptions{SimpleProtocol: true},
&expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
// Test multiple args in single query
{
expectedInt64 := int64(234423)
expectedFloat64 := float64(-0.2312)
expectedBool := true
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
expectedString := "test"
var actualInt64 int64
var actualFloat64 float64
var actualBool bool
var actualBytes []byte
var actualString string
err := conn.QueryRow(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
&pgx.QueryExOptions{SimpleProtocol: true},
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
if err != nil {
t.Error(err)
}
if expectedInt64 != actualInt64 {
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
}
if expectedFloat64 != actualFloat64 {
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
}
if expectedBool != actualBool {
t.Errorf("expected %v got %v", expectedBool, actualBool)
}
if bytes.Compare(expectedBytes, actualBytes) != 0 {
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
}
if expectedString != actualString {
t.Errorf("expected %v got %v", expectedString, actualString)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
// Test dangerous cases
{
expected := "foo';drop table users;"
var actual string
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
if !conn.LastStmtSent() {
t.Error("Expected LastStmtSent to return true")
}
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
var expected string
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
"test",
).Scan(&expected)
if err == nil {
t.Error("expected error when client_encoding not UTF8, but no error occurred")
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "set standard_conforming_strings to off")
var expected string
err := conn.QueryRow(
context.Background(),
"select $1",
&pgx.QueryExOptions{SimpleProtocol: true},
`\'; drop table users; --`,
).Scan(&expected)
if err == nil {
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
}
ensureConnValid(t, conn)
}
func TestQueryCloseBefore(t *testing.T) {
t.Parallel()

View File

@ -239,40 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
return nil, driver.ErrBadConn
}
var rows pgx.Rows
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
psname := fmt.Sprintf("stdlibpx%v", &argsV)
if !c.connConfig.PreferSimpleProtocol {
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
psname := fmt.Sprintf("stdlibpx%v", &argsV)
ps, err := c.conn.PrepareEx(ctx, psname, query, nil)
if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, psname, argsV)
}
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
ps, err := c.conn.PrepareEx(ctx, psname, query, nil)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
return nil, err
}
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, psname, argsV)
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {

View File

@ -1054,89 +1054,6 @@ func TestRowsColumnTypes(t *testing.T) {
}
}
func TestSimpleQueryLifeCycle(t *testing.T) {
// TODO - need to use new method of establishing connection with pgx specific configuration
// driverConfig := stdlib.DriverConfig{
// ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
// }
// stdlib.RegisterDriverConfig(&driverConfig)
// defer stdlib.UnregisterDriverConfig(&driverConfig)
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// t.Fatalf("sql.Open failed: %v", err)
// }
// defer closeDB(t, db)
// rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
// if err != nil {
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
// }
// rowCount := int64(0)
// for rows.Next() {
// rowCount++
// var (
// s string
// n int64
// )
// if err := rows.Scan(&s, &n); err != nil {
// t.Fatalf("rows.Scan unexpectedly failed: %v", err)
// }
// if s != "foo" {
// t.Errorf(`Expected "foo", received "%v"`, s)
// }
// if n != rowCount {
// t.Errorf("Expected %d, received %d", rowCount, n)
// }
// }
// if err = rows.Err(); err != nil {
// t.Fatalf("rows.Err unexpectedly is: %v", err)
// }
// if rowCount != 10 {
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
// }
// err = rows.Close()
// if err != nil {
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
// }
// rows, err = db.Query("select 1 where false")
// if err != nil {
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
// }
// rowCount = int64(0)
// for rows.Next() {
// rowCount++
// }
// if err = rows.Err(); err != nil {
// t.Fatalf("rows.Err unexpectedly is: %v", err)
// }
// if rowCount != 0 {
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
// }
// err = rows.Close()
// if err != nil {
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
// }
// ensureConnValid(t, db)
}
// https://github.com/jackc/pgx/issues/409
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
db := openDB(t)

8
tx.go
View File

@ -172,19 +172,19 @@ func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOp
}
// Query delegates to the underlying *Conn
func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) {
func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
if tx.status != TxStatusInProgress {
// Because checking for errors can be deferred to the *Rows, build one with the error
err := ErrTxClosed
return &connRows{closed: true, err: err}, err
}
return tx.conn.Query(ctx, sql, optionsAndArgs...)
return tx.conn.Query(ctx, sql, args...)
}
// QueryRow delegates to the underlying *Conn
func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row {
rows, _ := tx.Query(ctx, sql, optionsAndArgs...)
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
rows, _ := tx.Query(ctx, sql, args...)
return (*connRow)(rows.(*connRows))
}