From c53c9e6eb5634d800691af2aa3c33dde97e8032e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Apr 2019 11:39:01 -0500 Subject: [PATCH] Remove simple protocol and one round trip query options It is impossible to guarantee that the a query executed with the simple protocol will behave the same as with the extended protocol. This is because the normal pgx path relies on knowing the OID of query parameters. Without this encoding a value can only be determined by the value instead of the combination of value and PostgreSQL type. For example, how should a []int32 be encoded? It might be encoded into a PostgreSQL int4[] or json. Removal also simplifies the core query path. The primary reason for the simple protocol is for servers like PgBouncer that may not be able to support normal prepared statements. After further research it appears that issuing a "flush" instead "sync" after preparing the unnamed statement would allow PgBouncer to work. The one round trip mode can be better handled with prepared statements. As a last resort, all original server functionality can still be accessed by dropping down to PgConn. --- conn.go | 11 -- conn_test.go | 63 ------ internal/sanitize/sanitize.go | 237 ---------------------- internal/sanitize/sanitize_test.go | 175 ----------------- pgtype/cid_test.go | 3 - pgtype/testutil/testutil.go | 30 --- pgtype/xid_test.go | 3 - pool/common_test.go | 4 +- pool/conn.go | 8 +- pool/pool.go | 8 +- pool/tx.go | 8 +- query.go | 69 +------ query_test.go | 305 ----------------------------- stdlib/sql.go | 38 +--- stdlib/sql_test.go | 83 -------- tx.go | 8 +- 16 files changed, 32 insertions(+), 1021 deletions(-) delete mode 100644 internal/sanitize/sanitize.go delete mode 100644 internal/sanitize/sanitize_test.go diff --git a/conn.go b/conn.go index 36503c18..fdd649f8 100644 --- a/conn.go +++ b/conn.go @@ -45,17 +45,6 @@ type ConnConfig struct { Logger Logger LogLevel LogLevel CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc. - - // PreferSimpleProtocol disables implicit prepared statement usage. By default - // pgx automatically uses the unnamed prepared statement for Query and - // QueryRow. It also uses a prepared statement when Exec has arguments. This - // can improve performance due to being able to use the binary format. It also - // does not rely on client side parameter sanitization. However, it does incur - // two round-trips per query and may be incompatible proxies such as - // PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be - // used by default. The same functionality can be controlled on a per query - // basis by setting QueryExOptions.SimpleProtocol. - PreferSimpleProtocol bool } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. diff --git a/conn_test.go b/conn_test.go index 0febefd0..a4b89cec 100644 --- a/conn_test.go +++ b/conn_test.go @@ -78,31 +78,6 @@ func TestConnect(t *testing.T) { } } -func TestConnectWithPreferSimpleProtocol(t *testing.T) { - t.Parallel() - - connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE")) - connConfig.PreferSimpleProtocol = true - - conn := mustConnect(t, connConfig) - defer closeConn(t, conn) - - // If simple protocol is used we should be able to correctly scan the result - // into a pgtype.Text as the integer will have been encoded in text. - - var s pgtype.Text - err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s) - if err != nil { - t.Fatal(err) - } - - if s.Get() != "42" { - t.Fatalf(`expected "42", got %v`, s) - } - - ensureConnValid(t, conn) -} - func TestExec(t *testing.T) { t.Parallel() @@ -285,44 +260,6 @@ func TestExecExtendedProtocol(t *testing.T) { ensureConnValid(t, conn) } -func TestExecSimpleProtocol(t *testing.T) { - t.Skip("TODO when with simple protocol supported in connection") - // t.Parallel() - - // conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - // defer closeConn(t, conn) - - // ctx, cancelFunc := context.WithCancel(context.Background()) - // defer cancelFunc() - - // commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil) - // if err != nil { - // t.Fatal(err) - // } - // if string(commandTag) != "CREATE TABLE" { - // t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - // } - // if !conn.LastStmtSent() { - // t.Error("Expected LastStmtSent to return true") - // } - - // commandTag, err = conn.ExecEx( - // ctx, - // "insert into foo(name) values($1);", - // &pgx.QueryExOptions{SimpleProtocol: true}, - // "bar'; drop table foo;--", - // ) - // if err != nil { - // t.Fatal(err) - // } - // if string(commandTag) != "INSERT 0 1" { - // t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - // } - // if !conn.LastStmtSent() { - // t.Error("Expected LastStmtSent to return true") - // } -} - func TestExecExFailureCloseBefore(t *testing.T) { t.Parallel() diff --git a/internal/sanitize/sanitize.go b/internal/sanitize/sanitize.go deleted file mode 100644 index 53543b89..00000000 --- a/internal/sanitize/sanitize.go +++ /dev/null @@ -1,237 +0,0 @@ -package sanitize - -import ( - "bytes" - "encoding/hex" - "strconv" - "strings" - "time" - "unicode/utf8" - - "github.com/pkg/errors" -) - -// Part is either a string or an int. A string is raw SQL. An int is a -// argument placeholder. -type Part interface{} - -type Query struct { - Parts []Part -} - -func (q *Query) Sanitize(args ...interface{}) (string, error) { - argUse := make([]bool, len(args)) - buf := &bytes.Buffer{} - - for _, part := range q.Parts { - var str string - switch part := part.(type) { - case string: - str = part - case int: - argIdx := part - 1 - if argIdx >= len(args) { - return "", errors.Errorf("insufficient arguments") - } - arg := args[argIdx] - switch arg := arg.(type) { - case nil: - str = "null" - case int64: - str = strconv.FormatInt(arg, 10) - case float64: - str = strconv.FormatFloat(arg, 'f', -1, 64) - case bool: - str = strconv.FormatBool(arg) - case []byte: - str = QuoteBytes(arg) - case string: - str = QuoteString(arg) - case time.Time: - str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'") - default: - return "", errors.Errorf("invalid arg type: %T", arg) - } - argUse[argIdx] = true - default: - return "", errors.Errorf("invalid Part type: %T", part) - } - buf.WriteString(str) - } - - for i, used := range argUse { - if !used { - return "", errors.Errorf("unused argument: %d", i) - } - } - return buf.String(), nil -} - -func NewQuery(sql string) (*Query, error) { - l := &sqlLexer{ - src: sql, - stateFn: rawState, - } - - for l.stateFn != nil { - l.stateFn = l.stateFn(l) - } - - query := &Query{Parts: l.parts} - - return query, nil -} - -func QuoteString(str string) string { - return "'" + strings.Replace(str, "'", "''", -1) + "'" -} - -func QuoteBytes(buf []byte) string { - return `'\x` + hex.EncodeToString(buf) + "'" -} - -type sqlLexer struct { - src string - start int - pos int - stateFn stateFn - parts []Part -} - -type stateFn func(*sqlLexer) stateFn - -func rawState(l *sqlLexer) stateFn { - for { - r, width := utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - - switch r { - case 'e', 'E': - nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) - if nextRune == '\'' { - l.pos += width - return escapeStringState - } - case '\'': - return singleQuoteState - case '"': - return doubleQuoteState - case '$': - nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:]) - if '0' <= nextRune && nextRune <= '9' { - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos-width]) - } - l.start = l.pos - return placeholderState - } - case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos - } - return nil - } - } -} - -func singleQuoteState(l *sqlLexer) stateFn { - for { - r, width := utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - - switch r { - case '\'': - nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) - if nextRune != '\'' { - return rawState - } - l.pos += width - case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos - } - return nil - } - } -} - -func doubleQuoteState(l *sqlLexer) stateFn { - for { - r, width := utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - - switch r { - case '"': - nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) - if nextRune != '"' { - return rawState - } - l.pos += width - case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos - } - return nil - } - } -} - -// placeholderState consumes a placeholder value. The $ must have already has -// already been consumed. The first rune must be a digit. -func placeholderState(l *sqlLexer) stateFn { - num := 0 - - for { - r, width := utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - - if '0' <= r && r <= '9' { - num *= 10 - num += int(r - '0') - } else { - l.parts = append(l.parts, num) - l.pos -= width - l.start = l.pos - return rawState - } - } -} - -func escapeStringState(l *sqlLexer) stateFn { - for { - r, width := utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - - switch r { - case '\\': - _, width = utf8.DecodeRuneInString(l.src[l.pos:]) - l.pos += width - case '\'': - nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:]) - if nextRune != '\'' { - return rawState - } - l.pos += width - case utf8.RuneError: - if l.pos-l.start > 0 { - l.parts = append(l.parts, l.src[l.start:l.pos]) - l.start = l.pos - } - return nil - } - } -} - -// SanitizeSQL replaces placeholder values with args. It quotes and escapes args -// as necessary. This function is only safe when standard_conforming_strings is -// on. -func SanitizeSQL(sql string, args ...interface{}) (string, error) { - query, err := NewQuery(sql) - if err != nil { - return "", err - } - return query.Sanitize(args...) -} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go deleted file mode 100644 index 9597840e..00000000 --- a/internal/sanitize/sanitize_test.go +++ /dev/null @@ -1,175 +0,0 @@ -package sanitize_test - -import ( - "testing" - - "github.com/jackc/pgx/internal/sanitize" -) - -func TestNewQuery(t *testing.T) { - successTests := []struct { - sql string - expected sanitize.Query - }{ - { - sql: "select 42", - expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, - }, - { - sql: "select $1", - expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - }, - { - sql: "select 'quoted $42', $1", - expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}}, - }, - { - sql: `select "doubled quoted $42", $1`, - expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}}, - }, - { - sql: "select 'foo''bar', $1", - expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}}, - }, - { - sql: `select "foo""bar", $1`, - expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}}, - }, - { - sql: "select '''', $1", - expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}}, - }, - { - sql: `select """", $1`, - expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}}, - }, - { - sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11", - expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}}, - }, - { - sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`, - expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}}, - }, - { - sql: `select E'escape string\' $42', $1`, - expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}}, - }, - { - sql: `select e'escape string\' $42', $1`, - expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}}, - }, - } - - for i, tt := range successTests { - query, err := sanitize.NewQuery(tt.sql) - if err != nil { - t.Errorf("%d. %v", i, err) - } - - if len(query.Parts) == len(tt.expected.Parts) { - for j := range query.Parts { - if query.Parts[j] != tt.expected.Parts[j] { - t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j]) - } - } - } else { - t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts) - } - } -} - -func TestQuerySanitize(t *testing.T) { - successfulTests := []struct { - query sanitize.Query - args []interface{} - expected string - }{ - { - query: sanitize.Query{Parts: []sanitize.Part{"select 42"}}, - args: []interface{}{}, - expected: `select 42`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{int64(42)}, - expected: `select 42`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{float64(1.23)}, - expected: `select 1.23`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{true}, - expected: `select true`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{[]byte{0, 1, 2, 3, 255}}, - expected: `select '\x00010203ff'`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{nil}, - expected: `select null`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foobar"}, - expected: `select 'foobar'`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{"foo'bar"}, - expected: `select 'foo''bar'`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{`foo\'bar`}, - expected: `select 'foo\''bar'`, - }, - } - - for i, tt := range successfulTests { - actual, err := tt.query.Sanitize(tt.args...) - if err != nil { - t.Errorf("%d. %v", i, err) - continue - } - - if tt.expected != actual { - t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual) - } - } - - errorTests := []struct { - query sanitize.Query - args []interface{} - expected string - }{ - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}}, - args: []interface{}{int64(42)}, - expected: `insufficient arguments`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}}, - args: []interface{}{int64(42)}, - expected: `unused argument: 0`, - }, - { - query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}}, - args: []interface{}{42}, - expected: `invalid arg type: int`, - }, - } - - for i, tt := range errorTests { - _, err := tt.query.Sanitize(tt.args...) - if err == nil || err.Error() != tt.expected { - t.Errorf("%d. expected error %v, got %v", i, tt.expected, err) - } - } -} diff --git a/pgtype/cid_test.go b/pgtype/cid_test.go index 0dfc56d4..924e4cf3 100644 --- a/pgtype/cid_test.go +++ b/pgtype/cid_test.go @@ -20,9 +20,6 @@ func TestCIDTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - // No direct conversion from int to cid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 6ea3a69e..462549a7 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -98,7 +98,6 @@ func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } @@ -150,35 +149,6 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } } -func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { - conn := MustConnectPgx(t) - defer MustCloseContext(t, conn) - - for i, v := range values { - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } - - result := reflect.New(reflect.TypeOf(derefV)) - err := conn.QueryRow( - context.Background(), - fmt.Sprintf("select ($1)::%s", pgTypeName), - &pgx.QueryExOptions{SimpleProtocol: true}, - v, - ).Scan(result.Interface()) - if err != nil { - t.Errorf("Simple protocol %d: %v", i, err) - } - - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface()) - } - } -} - func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) { conn := MustConnectDatabaseSQL(t, driverName) defer MustClose(t, conn) diff --git a/pgtype/xid_test.go b/pgtype/xid_test.go index d0f3f0ab..594d1214 100644 --- a/pgtype/xid_test.go +++ b/pgtype/xid_test.go @@ -20,9 +20,6 @@ func TestXIDTranscode(t *testing.T) { testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc) - // No direct conversion from int to xid, convert through text - testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc) - for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} { testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc) } diff --git a/pool/common_test.go b/pool/common_test.go index e53bea8b..9e7ab947 100644 --- a/pool/common_test.go +++ b/pool/common_test.go @@ -37,7 +37,7 @@ func testExec(t *testing.T, db execer) { } type queryer interface { - Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) + Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) } func testQuery(t *testing.T, db queryer) { @@ -59,7 +59,7 @@ func testQuery(t *testing.T, db queryer) { } type queryRower interface { - QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row + QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row } func testQueryRow(t *testing.T, db queryRower) { diff --git a/pool/conn.go b/pool/conn.go index 86dc9507..7b11c699 100644 --- a/pool/conn.go +++ b/pool/conn.go @@ -53,12 +53,12 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) ( return c.Conn().Exec(ctx, sql, arguments...) } -func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { - return c.Conn().Query(ctx, sql, optionsAndArgs...) +func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + return c.Conn().Query(ctx, sql, args...) } -func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { - return c.Conn().QueryRow(ctx, sql, optionsAndArgs...) +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + return c.Conn().QueryRow(ctx, sql, args...) } func (c *Conn) Begin() (*pgx.Tx, error) { diff --git a/pool/pool.go b/pool/pool.go index 11401de8..3be32774 100644 --- a/pool/pool.go +++ b/pool/pool.go @@ -68,13 +68,13 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) ( return c.Exec(ctx, sql, arguments...) } -func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { +func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { c, err := p.Acquire(ctx) if err != nil { return errRows{err: err}, err } - rows, err := c.Query(ctx, sql, optionsAndArgs...) + rows, err := c.Query(ctx, sql, args...) if err != nil { c.Release() return errRows{err: err}, err @@ -83,13 +83,13 @@ func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interfac return &poolRows{r: rows, c: c}, nil } -func (p *Pool) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { +func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { c, err := p.Acquire(ctx) if err != nil { return errRow{err: err} } - row := c.QueryRow(ctx, sql, optionsAndArgs...) + row := c.QueryRow(ctx, sql, args...) return &poolRow{r: row, c: c} } diff --git a/pool/tx.go b/pool/tx.go index 4ab1c2f9..9f7b231a 100644 --- a/pool/tx.go +++ b/pool/tx.go @@ -38,10 +38,10 @@ func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (p return tx.c.Exec(ctx, sql, arguments...) } -func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) { - return tx.c.Query(ctx, sql, optionsAndArgs...) +func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { + return tx.c.Query(ctx, sql, args...) } -func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row { - return tx.c.QueryRow(ctx, sql, optionsAndArgs...) +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { + return tx.c.QueryRow(ctx, sql, args...) } diff --git a/query.go b/query.go index 5cb503ba..c6c8132f 100644 --- a/query.go +++ b/query.go @@ -9,7 +9,6 @@ import ( "github.com/pkg/errors" "github.com/jackc/pgconn" - "github.com/jackc/pgx/internal/sanitize" "github.com/jackc/pgx/pgtype" ) @@ -288,31 +287,12 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows { return r } -type QueryExOptions struct { - // When ParameterOIDs are present and the query is not a prepared statement, - // then ParameterOIDs and ResultFormatCodes will be used to avoid an extra - // network round-trip. - ParameterOIDs []pgtype.OID - ResultFormatCodes []int16 - - SimpleProtocol bool -} - // Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is // allowed to ignore the error returned from Query and handle it in Rows. -func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) { +func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { c.lastStmtSent = false // rows = c.getRows(sql, args) - var options *QueryExOptions - args := optionsAndArgs - if len(optionsAndArgs) > 0 { - if o, ok := optionsAndArgs[0].(*QueryExOptions); ok { - options = o - args = optionsAndArgs[1:] - } - } - rows := &connRows{ conn: c, startTime: time.Now(), @@ -332,27 +312,6 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac // return rows, rows.err // } - var err error - if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) { - sql, err = c.sanitizeForSimpleQuery(sql, args...) - if err != nil { - rows.fatal(err) - return rows, err - } - - c.lastStmtSent = true - rows.multiResultReader = c.pgConn.Exec(ctx, sql) - if rows.multiResultReader.NextResult() { - rows.resultReader = rows.multiResultReader.ResultReader() - } else { - err = rows.multiResultReader.Close() - rows.fatal(err) - return rows, err - } - - return rows, nil - } - // if options != nil && len(options.ParameterOIDs) > 0 { // buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args) @@ -427,6 +386,7 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac } rows.sql = ps.SQL + var err error args, err = convertDriverValuers(args) if err != nil { rows.fatal(err) @@ -461,31 +421,10 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac return rows, rows.err } -func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { - if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { - return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") - } - - if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { - return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") - } - - var err error - valueArgs := make([]interface{}, len(args)) - for i, a := range args { - valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) - if err != nil { - return "", err - } - } - - return sanitize.SanitizeSQL(sql, valueArgs...) -} - // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. -func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row { - rows, _ := c.Query(ctx, sql, optionsAndArgs...) +func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := c.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } diff --git a/query_test.go b/query_test.go index 0f320d8a..c1b89656 100644 --- a/query_test.go +++ b/query_test.go @@ -908,44 +908,6 @@ func TestQueryRowErrors(t *testing.T) { } } -func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - sql := ` - with t as ( - select 1::int8 as some_int, 'foo'::text as some_text - ) - select some_int from t where some_text = $1` - paramOIDs := []pgtype.OID{pgtype.TextArrayOID} - queryArgs := []interface{}{"bar"} - queryOptions := &pgx.QueryExOptions{ - ParameterOIDs: paramOIDs, - ResultFormatCodes: []int16{pgx.BinaryFormatCode}, - } - optionsAndArgs := append([]interface{}{queryOptions}, queryArgs...) - expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)" - var result int64 - - err := conn.QueryRow( - context.Background(), - sql, - optionsAndArgs..., - ).Scan(&result) - - if err == nil { - t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs) - } - if err != nil && !strings.Contains(err.Error(), expectedErr) { - t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", - expectedErr, err, sql, paramOIDs, queryArgs) - } - - ensureConnValid(t, conn) -} - func TestQueryRowNoResults(t *testing.T) { t.Parallel() @@ -1302,273 +1264,6 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { ensureConnValid(t, conn) } -func TestConnQueryRowSingleRoundTrip(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - var result int32 - err := conn.QueryRow( - context.Background(), - "select $1 + $2", - &pgx.QueryExOptions{ - ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID}, - ResultFormatCodes: []int16{pgx.BinaryFormatCode}, - }, - 1, 2, - ).Scan(&result) - if err != nil { - t.Fatal(err) - } - if result != 3 { - t.Fatalf("result => %d, want %d", result, 3) - } - - ensureConnValid(t, conn) -} - -func TestConnSimpleProtocol(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - // Test all supported low-level types - - { - expected := int64(42) - var actual int64 - err := conn.QueryRow( - context.Background(), - "select $1::int8", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - { - expected := float64(1.23) - var actual float64 - err := conn.QueryRow( - context.Background(), - "select $1::float8", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - { - expected := true - var actual bool - err := conn.QueryRow( - context.Background(), - "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - { - expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95} - var actual []byte - err := conn.QueryRow( - context.Background(), - "select $1::bytea", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if bytes.Compare(actual, expected) != 0 { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - { - expected := "test" - var actual string - err := conn.QueryRow( - context.Background(), - "select $1::text", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - // Test high-level type - - { - expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present} - actual := expected - err := conn.QueryRow( - context.Background(), - "select $1::circle", - &pgx.QueryExOptions{SimpleProtocol: true}, - &expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - // Test multiple args in single query - - { - expectedInt64 := int64(234423) - expectedFloat64 := float64(-0.2312) - expectedBool := true - expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223} - expectedString := "test" - var actualInt64 int64 - var actualFloat64 float64 - var actualBool bool - var actualBytes []byte - var actualString string - err := conn.QueryRow( - context.Background(), - "select $1::int8, $2::float8, $3, $4::bytea, $5::text", - &pgx.QueryExOptions{SimpleProtocol: true}, - expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString, - ).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString) - if err != nil { - t.Error(err) - } - if expectedInt64 != actualInt64 { - t.Errorf("expected %v got %v", expectedInt64, actualInt64) - } - if expectedFloat64 != actualFloat64 { - t.Errorf("expected %v got %v", expectedFloat64, actualFloat64) - } - if expectedBool != actualBool { - t.Errorf("expected %v got %v", expectedBool, actualBool) - } - if bytes.Compare(expectedBytes, actualBytes) != 0 { - t.Errorf("expected %v got %v", expectedBytes, actualBytes) - } - if expectedString != actualString { - t.Errorf("expected %v got %v", expectedString, actualString) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - // Test dangerous cases - - { - expected := "foo';drop table users;" - var actual string - err := conn.QueryRow( - context.Background(), - "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, - expected, - ).Scan(&actual) - if err != nil { - t.Error(err) - } - if expected != actual { - t.Errorf("expected %v got %v", expected, actual) - } - if !conn.LastStmtSent() { - t.Error("Expected LastStmtSent to return true") - } - } - - ensureConnValid(t, conn) -} - -func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - mustExec(t, conn, "set client_encoding to 'SQL_ASCII'") - - var expected string - err := conn.QueryRow( - context.Background(), - "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, - "test", - ).Scan(&expected) - if err == nil { - t.Error("expected error when client_encoding not UTF8, but no error occurred") - } - - ensureConnValid(t, conn) -} - -func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - mustExec(t, conn, "set standard_conforming_strings to off") - - var expected string - err := conn.QueryRow( - context.Background(), - "select $1", - &pgx.QueryExOptions{SimpleProtocol: true}, - `\'; drop table users; --`, - ).Scan(&expected) - if err == nil { - t.Error("expected error when standard_conforming_strings is off, but no error occurred") - } - - ensureConnValid(t, conn) -} - func TestQueryCloseBefore(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index 0ba56510..317d7af4 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -239,40 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - var rows pgx.Rows + // TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name + psname := fmt.Sprintf("stdlibpx%v", &argsV) - if !c.connConfig.PreferSimpleProtocol { - // TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name - psname := fmt.Sprintf("stdlibpx%v", &argsV) - - ps, err := c.conn.PrepareEx(ctx, psname, query, nil) - if err != nil { - // since PrepareEx failed, we didn't actually get to send the values, so - // we can safely retry - if _, is := err.(net.Error); is { - return nil, driver.ErrBadConn - } - return nil, err - } - - restrictBinaryToDatabaseSqlTypes(ps) - return c.queryPreparedContext(ctx, psname, argsV) - } - - rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...) + ps, err := c.conn.PrepareEx(ctx, psname, query, nil) if err != nil { - // if we got a network error before we had a chance to send the query, retry - if !c.conn.LastStmtSent() { - if _, is := err.(net.Error); is { - return nil, driver.ErrBadConn - } + // since PrepareEx failed, we didn't actually get to send the values, so + // we can safely retry + if _, is := err.(net.Error); is { + return nil, driver.ErrBadConn } return nil, err } - // Preload first row because otherwise we won't know what columns are available when database/sql asks. - more := rows.Next() - return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil + restrictBinaryToDatabaseSqlTypes(ps) + return c.queryPreparedContext(ctx, psname, argsV) + } func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 429f4dce..2338e97f 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1054,89 +1054,6 @@ func TestRowsColumnTypes(t *testing.T) { } } -func TestSimpleQueryLifeCycle(t *testing.T) { - // TODO - need to use new method of establishing connection with pgx specific configuration - - // driverConfig := stdlib.DriverConfig{ - // ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true}, - // } - - // stdlib.RegisterDriverConfig(&driverConfig) - // defer stdlib.UnregisterDriverConfig(&driverConfig) - - // db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) - // if err != nil { - // t.Fatalf("sql.Open failed: %v", err) - // } - // defer closeDB(t, db) - - // rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) - // if err != nil { - // t.Fatalf("stmt.Query unexpectedly failed: %v", err) - // } - - // rowCount := int64(0) - - // for rows.Next() { - // rowCount++ - // var ( - // s string - // n int64 - // ) - - // if err := rows.Scan(&s, &n); err != nil { - // t.Fatalf("rows.Scan unexpectedly failed: %v", err) - // } - - // if s != "foo" { - // t.Errorf(`Expected "foo", received "%v"`, s) - // } - - // if n != rowCount { - // t.Errorf("Expected %d, received %d", rowCount, n) - // } - // } - - // if err = rows.Err(); err != nil { - // t.Fatalf("rows.Err unexpectedly is: %v", err) - // } - - // if rowCount != 10 { - // t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - // } - - // err = rows.Close() - // if err != nil { - // t.Fatalf("rows.Close unexpectedly failed: %v", err) - // } - - // rows, err = db.Query("select 1 where false") - // if err != nil { - // t.Fatalf("stmt.Query unexpectedly failed: %v", err) - // } - - // rowCount = int64(0) - - // for rows.Next() { - // rowCount++ - // } - - // if err = rows.Err(); err != nil { - // t.Fatalf("rows.Err unexpectedly is: %v", err) - // } - - // if rowCount != 0 { - // t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) - // } - - // err = rows.Close() - // if err != nil { - // t.Fatalf("rows.Close unexpectedly failed: %v", err) - // } - - // ensureConnValid(t, db) -} - // https://github.com/jackc/pgx/issues/409 func TestScanJSONIntoJSONRawMessage(t *testing.T) { db := openDB(t) diff --git a/tx.go b/tx.go index 19b9159b..96afadeb 100644 --- a/tx.go +++ b/tx.go @@ -172,19 +172,19 @@ func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOp } // Query delegates to the underlying *Conn -func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) { +func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { if tx.status != TxStatusInProgress { // Because checking for errors can be deferred to the *Rows, build one with the error err := ErrTxClosed return &connRows{closed: true, err: err}, err } - return tx.conn.Query(ctx, sql, optionsAndArgs...) + return tx.conn.Query(ctx, sql, args...) } // QueryRow delegates to the underlying *Conn -func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row { - rows, _ := tx.Query(ctx, sql, optionsAndArgs...) +func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { + rows, _ := tx.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) }