mirror of https://github.com/jackc/pgx.git
Remove simple protocol and one round trip query options
It is impossible to guarantee that the a query executed with the simple protocol will behave the same as with the extended protocol. This is because the normal pgx path relies on knowing the OID of query parameters. Without this encoding a value can only be determined by the value instead of the combination of value and PostgreSQL type. For example, how should a []int32 be encoded? It might be encoded into a PostgreSQL int4[] or json. Removal also simplifies the core query path. The primary reason for the simple protocol is for servers like PgBouncer that may not be able to support normal prepared statements. After further research it appears that issuing a "flush" instead "sync" after preparing the unnamed statement would allow PgBouncer to work. The one round trip mode can be better handled with prepared statements. As a last resort, all original server functionality can still be accessed by dropping down to PgConn.pull/483/head
parent
5a374c467f
commit
c53c9e6eb5
11
conn.go
11
conn.go
|
@ -45,17 +45,6 @@ type ConnConfig struct {
|
|||
Logger Logger
|
||||
LogLevel LogLevel
|
||||
CustomConnInfo func(*Conn) (*pgtype.ConnInfo, error) // Callback function to implement connection strategies for different backends. crate, pgbouncer, pgpool, etc.
|
||||
|
||||
// PreferSimpleProtocol disables implicit prepared statement usage. By default
|
||||
// pgx automatically uses the unnamed prepared statement for Query and
|
||||
// QueryRow. It also uses a prepared statement when Exec has arguments. This
|
||||
// can improve performance due to being able to use the binary format. It also
|
||||
// does not rely on client side parameter sanitization. However, it does incur
|
||||
// two round-trips per query and may be incompatible proxies such as
|
||||
// PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
|
||||
// used by default. The same functionality can be controlled on a per query
|
||||
// basis by setting QueryExOptions.SimpleProtocol.
|
||||
PreferSimpleProtocol bool
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
|
|
63
conn_test.go
63
conn_test.go
|
@ -78,31 +78,6 @@ func TestConnect(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestConnectWithPreferSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
connConfig.PreferSimpleProtocol = true
|
||||
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// If simple protocol is used we should be able to correctly scan the result
|
||||
// into a pgtype.Text as the integer will have been encoded in text.
|
||||
|
||||
var s pgtype.Text
|
||||
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if s.Get() != "42" {
|
||||
t.Fatalf(`expected "42", got %v`, s)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExec(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -285,44 +260,6 @@ func TestExecExtendedProtocol(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecSimpleProtocol(t *testing.T) {
|
||||
t.Skip("TODO when with simple protocol supported in connection")
|
||||
// t.Parallel()
|
||||
|
||||
// conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
// defer closeConn(t, conn)
|
||||
|
||||
// ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
// defer cancelFunc()
|
||||
|
||||
// commandTag, err := conn.ExecEx(ctx, "create temporary table foo(name varchar primary key);", nil)
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if string(commandTag) != "CREATE TABLE" {
|
||||
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
// }
|
||||
// if !conn.LastStmtSent() {
|
||||
// t.Error("Expected LastStmtSent to return true")
|
||||
// }
|
||||
|
||||
// commandTag, err = conn.ExecEx(
|
||||
// ctx,
|
||||
// "insert into foo(name) values($1);",
|
||||
// &pgx.QueryExOptions{SimpleProtocol: true},
|
||||
// "bar'; drop table foo;--",
|
||||
// )
|
||||
// if err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
// if string(commandTag) != "INSERT 0 1" {
|
||||
// t.Fatalf("Unexpected results from ExecEx: %v", commandTag)
|
||||
// }
|
||||
// if !conn.LastStmtSent() {
|
||||
// t.Error("Expected LastStmtSent to return true")
|
||||
// }
|
||||
}
|
||||
|
||||
func TestExecExFailureCloseBefore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -1,237 +0,0 @@
|
|||
package sanitize
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Part is either a string or an int. A string is raw SQL. An int is a
|
||||
// argument placeholder.
|
||||
type Part interface{}
|
||||
|
||||
type Query struct {
|
||||
Parts []Part
|
||||
}
|
||||
|
||||
func (q *Query) Sanitize(args ...interface{}) (string, error) {
|
||||
argUse := make([]bool, len(args))
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
for _, part := range q.Parts {
|
||||
var str string
|
||||
switch part := part.(type) {
|
||||
case string:
|
||||
str = part
|
||||
case int:
|
||||
argIdx := part - 1
|
||||
if argIdx >= len(args) {
|
||||
return "", errors.Errorf("insufficient arguments")
|
||||
}
|
||||
arg := args[argIdx]
|
||||
switch arg := arg.(type) {
|
||||
case nil:
|
||||
str = "null"
|
||||
case int64:
|
||||
str = strconv.FormatInt(arg, 10)
|
||||
case float64:
|
||||
str = strconv.FormatFloat(arg, 'f', -1, 64)
|
||||
case bool:
|
||||
str = strconv.FormatBool(arg)
|
||||
case []byte:
|
||||
str = QuoteBytes(arg)
|
||||
case string:
|
||||
str = QuoteString(arg)
|
||||
case time.Time:
|
||||
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
|
||||
default:
|
||||
return "", errors.Errorf("invalid arg type: %T", arg)
|
||||
}
|
||||
argUse[argIdx] = true
|
||||
default:
|
||||
return "", errors.Errorf("invalid Part type: %T", part)
|
||||
}
|
||||
buf.WriteString(str)
|
||||
}
|
||||
|
||||
for i, used := range argUse {
|
||||
if !used {
|
||||
return "", errors.Errorf("unused argument: %d", i)
|
||||
}
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func NewQuery(sql string) (*Query, error) {
|
||||
l := &sqlLexer{
|
||||
src: sql,
|
||||
stateFn: rawState,
|
||||
}
|
||||
|
||||
for l.stateFn != nil {
|
||||
l.stateFn = l.stateFn(l)
|
||||
}
|
||||
|
||||
query := &Query{Parts: l.parts}
|
||||
|
||||
return query, nil
|
||||
}
|
||||
|
||||
func QuoteString(str string) string {
|
||||
return "'" + strings.Replace(str, "'", "''", -1) + "'"
|
||||
}
|
||||
|
||||
func QuoteBytes(buf []byte) string {
|
||||
return `'\x` + hex.EncodeToString(buf) + "'"
|
||||
}
|
||||
|
||||
type sqlLexer struct {
|
||||
src string
|
||||
start int
|
||||
pos int
|
||||
stateFn stateFn
|
||||
parts []Part
|
||||
}
|
||||
|
||||
type stateFn func(*sqlLexer) stateFn
|
||||
|
||||
func rawState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case 'e', 'E':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune == '\'' {
|
||||
l.pos += width
|
||||
return escapeStringState
|
||||
}
|
||||
case '\'':
|
||||
return singleQuoteState
|
||||
case '"':
|
||||
return doubleQuoteState
|
||||
case '$':
|
||||
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if '0' <= nextRune && nextRune <= '9' {
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos-width])
|
||||
}
|
||||
l.start = l.pos
|
||||
return placeholderState
|
||||
}
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func singleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func doubleQuoteState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '"':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '"' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// placeholderState consumes a placeholder value. The $ must have already has
|
||||
// already been consumed. The first rune must be a digit.
|
||||
func placeholderState(l *sqlLexer) stateFn {
|
||||
num := 0
|
||||
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
if '0' <= r && r <= '9' {
|
||||
num *= 10
|
||||
num += int(r - '0')
|
||||
} else {
|
||||
l.parts = append(l.parts, num)
|
||||
l.pos -= width
|
||||
l.start = l.pos
|
||||
return rawState
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeStringState(l *sqlLexer) stateFn {
|
||||
for {
|
||||
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
|
||||
switch r {
|
||||
case '\\':
|
||||
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
l.pos += width
|
||||
case '\'':
|
||||
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
|
||||
if nextRune != '\'' {
|
||||
return rawState
|
||||
}
|
||||
l.pos += width
|
||||
case utf8.RuneError:
|
||||
if l.pos-l.start > 0 {
|
||||
l.parts = append(l.parts, l.src[l.start:l.pos])
|
||||
l.start = l.pos
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
|
||||
// as necessary. This function is only safe when standard_conforming_strings is
|
||||
// on.
|
||||
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
|
||||
query, err := NewQuery(sql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return query.Sanitize(args...)
|
||||
}
|
|
@ -1,175 +0,0 @@
|
|||
package sanitize_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
successTests := []struct {
|
||||
sql string
|
||||
expected sanitize.Query
|
||||
}{
|
||||
{
|
||||
sql: "select 42",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
},
|
||||
{
|
||||
sql: "select $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'quoted $42', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "doubled quoted $42", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select 'foo''bar', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select "foo""bar", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select '''', $1",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
|
||||
},
|
||||
{
|
||||
sql: `select """", $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
|
||||
},
|
||||
{
|
||||
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
|
||||
},
|
||||
{
|
||||
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
|
||||
},
|
||||
{
|
||||
sql: `select E'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
|
||||
},
|
||||
{
|
||||
sql: `select e'escape string\' $42', $1`,
|
||||
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successTests {
|
||||
query, err := sanitize.NewQuery(tt.sql)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
}
|
||||
|
||||
if len(query.Parts) == len(tt.expected.Parts) {
|
||||
for j := range query.Parts {
|
||||
if query.Parts[j] != tt.expected.Parts[j] {
|
||||
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
|
||||
}
|
||||
}
|
||||
} else {
|
||||
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQuerySanitize(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
|
||||
args: []interface{}{},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `select 42`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{float64(1.23)},
|
||||
expected: `select 1.23`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{true},
|
||||
expected: `select true`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
|
||||
expected: `select '\x00010203ff'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{nil},
|
||||
expected: `select null`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foobar"},
|
||||
expected: `select 'foobar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{"foo'bar"},
|
||||
expected: `select 'foo''bar'`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{`foo\'bar`},
|
||||
expected: `select 'foo\''bar'`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range successfulTests {
|
||||
actual, err := tt.query.Sanitize(tt.args...)
|
||||
if err != nil {
|
||||
t.Errorf("%d. %v", i, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.expected != actual {
|
||||
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
errorTests := []struct {
|
||||
query sanitize.Query
|
||||
args []interface{}
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `insufficient arguments`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
|
||||
args: []interface{}{int64(42)},
|
||||
expected: `unused argument: 0`,
|
||||
},
|
||||
{
|
||||
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
|
||||
args: []interface{}{42},
|
||||
expected: `invalid arg type: int`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range errorTests {
|
||||
_, err := tt.query.Sanitize(tt.args...)
|
||||
if err == nil || err.Error() != tt.expected {
|
||||
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,9 +20,6 @@ func TestCIDTranscode(t *testing.T) {
|
|||
|
||||
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
|
||||
// No direct conversion from int to cid, convert through text
|
||||
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
|
||||
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
|
|
|
@ -98,7 +98,6 @@ func TestSuccessfulTranscode(t testing.TB, pgTypeName string, values []interface
|
|||
|
||||
func TestSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
|
@ -150,35 +149,6 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []
|
|||
}
|
||||
}
|
||||
|
||||
func TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := MustConnectPgx(t)
|
||||
defer MustCloseContext(t, conn)
|
||||
|
||||
for i, v := range values {
|
||||
// Derefence value if it is a pointer
|
||||
derefV := v
|
||||
refVal := reflect.ValueOf(v)
|
||||
if refVal.Kind() == reflect.Ptr {
|
||||
derefV = refVal.Elem().Interface()
|
||||
}
|
||||
|
||||
result := reflect.New(reflect.TypeOf(derefV))
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
fmt.Sprintf("select ($1)::%s", pgTypeName),
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
v,
|
||||
).Scan(result.Interface())
|
||||
if err != nil {
|
||||
t.Errorf("Simple protocol %d: %v", i, err)
|
||||
}
|
||||
|
||||
if !eqFunc(result.Elem().Interface(), derefV) {
|
||||
t.Errorf("Simple protocol %d: expected %v, got %v", i, derefV, result.Elem().Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatabaseSQLSuccessfulTranscodeEqFunc(t testing.TB, driverName, pgTypeName string, values []interface{}, eqFunc func(a, b interface{}) bool) {
|
||||
conn := MustConnectDatabaseSQL(t, driverName)
|
||||
defer MustClose(t, conn)
|
||||
|
|
|
@ -20,9 +20,6 @@ func TestXIDTranscode(t *testing.T) {
|
|||
|
||||
testutil.TestPgxSuccessfulTranscodeEqFunc(t, pgTypeName, values, eqFunc)
|
||||
|
||||
// No direct conversion from int to xid, convert through text
|
||||
testutil.TestPgxSimpleProtocolSuccessfulTranscodeEqFunc(t, "text::"+pgTypeName, values, eqFunc)
|
||||
|
||||
for _, driverName := range []string{"github.com/lib/pq", "github.com/jackc/pgx/stdlib"} {
|
||||
testutil.TestDatabaseSQLSuccessfulTranscodeEqFunc(t, driverName, pgTypeName, values, eqFunc)
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ func testExec(t *testing.T, db execer) {
|
|||
}
|
||||
|
||||
type queryer interface {
|
||||
Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error)
|
||||
Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error)
|
||||
}
|
||||
|
||||
func testQuery(t *testing.T, db queryer) {
|
||||
|
@ -59,7 +59,7 @@ func testQuery(t *testing.T, db queryer) {
|
|||
}
|
||||
|
||||
type queryRower interface {
|
||||
QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row
|
||||
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
|
||||
}
|
||||
|
||||
func testQueryRow(t *testing.T, db queryRower) {
|
||||
|
|
|
@ -53,12 +53,12 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
return c.Conn().Exec(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
|
||||
return c.Conn().Query(ctx, sql, optionsAndArgs...)
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
||||
return c.Conn().Query(ctx, sql, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
|
||||
return c.Conn().QueryRow(ctx, sql, optionsAndArgs...)
|
||||
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
|
||||
return c.Conn().QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
func (c *Conn) Begin() (*pgx.Tx, error) {
|
||||
|
|
|
@ -68,13 +68,13 @@ func (p *Pool) Exec(ctx context.Context, sql string, arguments ...interface{}) (
|
|||
return c.Exec(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
|
||||
func (p *Pool) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return errRows{err: err}, err
|
||||
}
|
||||
|
||||
rows, err := c.Query(ctx, sql, optionsAndArgs...)
|
||||
rows, err := c.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
c.Release()
|
||||
return errRows{err: err}, err
|
||||
|
@ -83,13 +83,13 @@ func (p *Pool) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
|
|||
return &poolRows{r: rows, c: c}, nil
|
||||
}
|
||||
|
||||
func (p *Pool) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
|
||||
func (p *Pool) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
|
||||
c, err := p.Acquire(ctx)
|
||||
if err != nil {
|
||||
return errRow{err: err}
|
||||
}
|
||||
|
||||
row := c.QueryRow(ctx, sql, optionsAndArgs...)
|
||||
row := c.QueryRow(ctx, sql, args...)
|
||||
return &poolRow{r: row, c: c}
|
||||
}
|
||||
|
||||
|
|
|
@ -38,10 +38,10 @@ func (tx *Tx) Exec(ctx context.Context, sql string, arguments ...interface{}) (p
|
|||
return tx.c.Exec(ctx, sql, arguments...)
|
||||
}
|
||||
|
||||
func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (pgx.Rows, error) {
|
||||
return tx.c.Query(ctx, sql, optionsAndArgs...)
|
||||
func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) {
|
||||
return tx.c.Query(ctx, sql, args...)
|
||||
}
|
||||
|
||||
func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) pgx.Row {
|
||||
return tx.c.QueryRow(ctx, sql, optionsAndArgs...)
|
||||
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
|
||||
return tx.c.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
|
69
query.go
69
query.go
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/internal/sanitize"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
|
||||
|
@ -288,31 +287,12 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
|
|||
return r
|
||||
}
|
||||
|
||||
type QueryExOptions struct {
|
||||
// When ParameterOIDs are present and the query is not a prepared statement,
|
||||
// then ParameterOIDs and ResultFormatCodes will be used to avoid an extra
|
||||
// network round-trip.
|
||||
ParameterOIDs []pgtype.OID
|
||||
ResultFormatCodes []int16
|
||||
|
||||
SimpleProtocol bool
|
||||
}
|
||||
|
||||
// Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is
|
||||
// allowed to ignore the error returned from Query and handle it in Rows.
|
||||
func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) {
|
||||
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
||||
c.lastStmtSent = false
|
||||
// rows = c.getRows(sql, args)
|
||||
|
||||
var options *QueryExOptions
|
||||
args := optionsAndArgs
|
||||
if len(optionsAndArgs) > 0 {
|
||||
if o, ok := optionsAndArgs[0].(*QueryExOptions); ok {
|
||||
options = o
|
||||
args = optionsAndArgs[1:]
|
||||
}
|
||||
}
|
||||
|
||||
rows := &connRows{
|
||||
conn: c,
|
||||
startTime: time.Now(),
|
||||
|
@ -332,27 +312,6 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
|
|||
// return rows, rows.err
|
||||
// }
|
||||
|
||||
var err error
|
||||
if (options == nil && c.config.PreferSimpleProtocol) || (options != nil && options.SimpleProtocol) {
|
||||
sql, err = c.sanitizeForSimpleQuery(sql, args...)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
c.lastStmtSent = true
|
||||
rows.multiResultReader = c.pgConn.Exec(ctx, sql)
|
||||
if rows.multiResultReader.NextResult() {
|
||||
rows.resultReader = rows.multiResultReader.ResultReader()
|
||||
} else {
|
||||
err = rows.multiResultReader.Close()
|
||||
rows.fatal(err)
|
||||
return rows, err
|
||||
}
|
||||
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// if options != nil && len(options.ParameterOIDs) > 0 {
|
||||
|
||||
// buf, err := c.buildOneRoundTripQueryEx(c.wbuf, sql, options, args)
|
||||
|
@ -427,6 +386,7 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
|
|||
}
|
||||
rows.sql = ps.SQL
|
||||
|
||||
var err error
|
||||
args, err = convertDriverValuers(args)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
|
@ -461,31 +421,10 @@ func (c *Conn) Query(ctx context.Context, sql string, optionsAndArgs ...interfac
|
|||
return rows, rows.err
|
||||
}
|
||||
|
||||
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
|
||||
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
||||
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
||||
}
|
||||
|
||||
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
|
||||
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
||||
}
|
||||
|
||||
var err error
|
||||
valueArgs := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
return sanitize.SanitizeSQL(sql, valueArgs...)
|
||||
}
|
||||
|
||||
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
||||
// querying is deferred until calling Scan on the returned Row. That Row will
|
||||
// error with ErrNoRows if no rows are returned.
|
||||
func (c *Conn) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row {
|
||||
rows, _ := c.Query(ctx, sql, optionsAndArgs...)
|
||||
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
|
||||
rows, _ := c.Query(ctx, sql, args...)
|
||||
return (*connRow)(rows.(*connRows))
|
||||
}
|
||||
|
|
305
query_test.go
305
query_test.go
|
@ -908,44 +908,6 @@ func TestQueryRowErrors(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestQueryRowExErrorsWrongParameterOIDs(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
sql := `
|
||||
with t as (
|
||||
select 1::int8 as some_int, 'foo'::text as some_text
|
||||
)
|
||||
select some_int from t where some_text = $1`
|
||||
paramOIDs := []pgtype.OID{pgtype.TextArrayOID}
|
||||
queryArgs := []interface{}{"bar"}
|
||||
queryOptions := &pgx.QueryExOptions{
|
||||
ParameterOIDs: paramOIDs,
|
||||
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
|
||||
}
|
||||
optionsAndArgs := append([]interface{}{queryOptions}, queryArgs...)
|
||||
expectedErr := "operator does not exist: text = text[] (SQLSTATE 42883)"
|
||||
var result int64
|
||||
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
sql,
|
||||
optionsAndArgs...,
|
||||
).Scan(&result)
|
||||
|
||||
if err == nil {
|
||||
t.Errorf("Unexpected success (sql -> %v, paramOIDs -> %v, queryArgs -> %v)", sql, paramOIDs, queryArgs)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), expectedErr) {
|
||||
t.Errorf("Expected error to contain %s, but got %v (sql -> %v, paramOIDs -> %v, queryArgs -> %v)",
|
||||
expectedErr, err, sql, paramOIDs, queryArgs)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryRowNoResults(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -1302,273 +1264,6 @@ func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) {
|
|||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnQueryRowSingleRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
var result int32
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1 + $2",
|
||||
&pgx.QueryExOptions{
|
||||
ParameterOIDs: []pgtype.OID{pgtype.Int4OID, pgtype.Int4OID},
|
||||
ResultFormatCodes: []int16{pgx.BinaryFormatCode},
|
||||
},
|
||||
1, 2,
|
||||
).Scan(&result)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if result != 3 {
|
||||
t.Fatalf("result => %d, want %d", result, 3)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocol(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
// Test all supported low-level types
|
||||
|
||||
{
|
||||
expected := int64(42)
|
||||
var actual int64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := float64(1.23)
|
||||
var actual float64
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::float8",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := true
|
||||
var actual bool
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
|
||||
var actual []byte
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::bytea",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if bytes.Compare(actual, expected) != 0 {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
expected := "test"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test high-level type
|
||||
|
||||
{
|
||||
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
|
||||
actual := expected
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::circle",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
&expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test multiple args in single query
|
||||
|
||||
{
|
||||
expectedInt64 := int64(234423)
|
||||
expectedFloat64 := float64(-0.2312)
|
||||
expectedBool := true
|
||||
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
|
||||
expectedString := "test"
|
||||
var actualInt64 int64
|
||||
var actualFloat64 float64
|
||||
var actualBool bool
|
||||
var actualBytes []byte
|
||||
var actualString string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
|
||||
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expectedInt64 != actualInt64 {
|
||||
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
|
||||
}
|
||||
if expectedFloat64 != actualFloat64 {
|
||||
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
|
||||
}
|
||||
if expectedBool != actualBool {
|
||||
t.Errorf("expected %v got %v", expectedBool, actualBool)
|
||||
}
|
||||
if bytes.Compare(expectedBytes, actualBytes) != 0 {
|
||||
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
|
||||
}
|
||||
if expectedString != actualString {
|
||||
t.Errorf("expected %v got %v", expectedString, actualString)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
// Test dangerous cases
|
||||
|
||||
{
|
||||
expected := "foo';drop table users;"
|
||||
var actual string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
expected,
|
||||
).Scan(&actual)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if expected != actual {
|
||||
t.Errorf("expected %v got %v", expected, actual)
|
||||
}
|
||||
if !conn.LastStmtSent() {
|
||||
t.Error("Expected LastStmtSent to return true")
|
||||
}
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
"test",
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when client_encoding not UTF8, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, "set standard_conforming_strings to off")
|
||||
|
||||
var expected string
|
||||
err := conn.QueryRow(
|
||||
context.Background(),
|
||||
"select $1",
|
||||
&pgx.QueryExOptions{SimpleProtocol: true},
|
||||
`\'; drop table users; --`,
|
||||
).Scan(&expected)
|
||||
if err == nil {
|
||||
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryCloseBefore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
|
@ -239,40 +239,22 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
var rows pgx.Rows
|
||||
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
|
||||
psname := fmt.Sprintf("stdlibpx%v", &argsV)
|
||||
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
// TODO - remove hack that creates a new prepared statement for every query -- put in place because of problem preparing empty statement name
|
||||
psname := fmt.Sprintf("stdlibpx%v", &argsV)
|
||||
|
||||
ps, err := c.conn.PrepareEx(ctx, psname, query, nil)
|
||||
if err != nil {
|
||||
// since PrepareEx failed, we didn't actually get to send the values, so
|
||||
// we can safely retry
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPreparedContext(ctx, psname, argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.Query(ctx, query, namedValueToInterface(argsV)...)
|
||||
ps, err := c.conn.PrepareEx(ctx, psname, query, nil)
|
||||
if err != nil {
|
||||
// if we got a network error before we had a chance to send the query, retry
|
||||
if !c.conn.LastStmtSent() {
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// since PrepareEx failed, we didn't actually get to send the values, so
|
||||
// we can safely retry
|
||||
if _, is := err.(net.Error); is {
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPreparedContext(ctx, psname, argsV)
|
||||
|
||||
}
|
||||
|
||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||
|
|
|
@ -1054,89 +1054,6 @@ func TestRowsColumnTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||
// TODO - need to use new method of establishing connection with pgx specific configuration
|
||||
|
||||
// driverConfig := stdlib.DriverConfig{
|
||||
// ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||
// }
|
||||
|
||||
// stdlib.RegisterDriverConfig(&driverConfig)
|
||||
// defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
// if err != nil {
|
||||
// t.Fatalf("sql.Open failed: %v", err)
|
||||
// }
|
||||
// defer closeDB(t, db)
|
||||
|
||||
// rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
// if err != nil {
|
||||
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// rowCount := int64(0)
|
||||
|
||||
// for rows.Next() {
|
||||
// rowCount++
|
||||
// var (
|
||||
// s string
|
||||
// n int64
|
||||
// )
|
||||
|
||||
// if err := rows.Scan(&s, &n); err != nil {
|
||||
// t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// if s != "foo" {
|
||||
// t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
// }
|
||||
|
||||
// if n != rowCount {
|
||||
// t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if err = rows.Err(); err != nil {
|
||||
// t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
// }
|
||||
|
||||
// if rowCount != 10 {
|
||||
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
// }
|
||||
|
||||
// err = rows.Close()
|
||||
// if err != nil {
|
||||
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// rows, err = db.Query("select 1 where false")
|
||||
// if err != nil {
|
||||
// t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// rowCount = int64(0)
|
||||
|
||||
// for rows.Next() {
|
||||
// rowCount++
|
||||
// }
|
||||
|
||||
// if err = rows.Err(); err != nil {
|
||||
// t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
// }
|
||||
|
||||
// if rowCount != 0 {
|
||||
// t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
// }
|
||||
|
||||
// err = rows.Close()
|
||||
// if err != nil {
|
||||
// t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
// }
|
||||
|
||||
// ensureConnValid(t, db)
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/409
|
||||
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
|
||||
db := openDB(t)
|
||||
|
|
8
tx.go
8
tx.go
|
@ -172,19 +172,19 @@ func (tx *Tx) PrepareEx(ctx context.Context, name, sql string, opts *PrepareExOp
|
|||
}
|
||||
|
||||
// Query delegates to the underlying *Conn
|
||||
func (tx *Tx) Query(ctx context.Context, sql string, optionsAndArgs ...interface{}) (Rows, error) {
|
||||
func (tx *Tx) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
|
||||
if tx.status != TxStatusInProgress {
|
||||
// Because checking for errors can be deferred to the *Rows, build one with the error
|
||||
err := ErrTxClosed
|
||||
return &connRows{closed: true, err: err}, err
|
||||
}
|
||||
|
||||
return tx.conn.Query(ctx, sql, optionsAndArgs...)
|
||||
return tx.conn.Query(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// QueryRow delegates to the underlying *Conn
|
||||
func (tx *Tx) QueryRow(ctx context.Context, sql string, optionsAndArgs ...interface{}) Row {
|
||||
rows, _ := tx.Query(ctx, sql, optionsAndArgs...)
|
||||
func (tx *Tx) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
|
||||
rows, _ := tx.Query(ctx, sql, args...)
|
||||
return (*connRow)(rows.(*connRows))
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue