Restore simple protocol support

pull/586/head
Jack Christensen 2019-05-20 20:36:03 -05:00
parent 6d23b58b01
commit 29f02807b0
7 changed files with 861 additions and 3 deletions

92
conn.go
View File

@ -10,6 +10,7 @@ import (
"github.com/jackc/pgconn"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4/internal/sanitize"
)
const (
@ -24,6 +25,14 @@ type ConnConfig struct {
pgconn.Config
Logger Logger
LogLevel LogLevel
// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
// and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be
// used by default. The same functionality can be controlled on a per query basis by setting
// QueryExOptions.SimpleProtocol.
PreferSimpleProtocol bool
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
@ -390,6 +399,36 @@ func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (
}
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
simpleProtocol := c.config.PreferSimpleProtocol
optionLoop:
for len(arguments) > 0 {
switch arg := arguments[0].(type) {
case QuerySimpleProtocol:
simpleProtocol = bool(arg)
arguments = arguments[1:]
default:
break optionLoop
}
}
if simpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
if err != nil {
return nil, err
}
mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
result := mrr.ResultReader().Read()
err = mrr.Close()
return result.CommandTag, err
} else {
err = mrr.Close()
return nil, err
}
}
c.eqb.Reset()
if ps, ok := c.preparedStatements[sql]; ok {
@ -495,6 +534,9 @@ func (c *Conn) getRows(sql string, args []interface{}) *connRows {
return r
}
// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
type QuerySimpleProtocol bool
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
type QueryResultFormats []int16
@ -506,6 +548,7 @@ type QueryResultFormatsByOID map[pgtype.OID]int16
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
var resultFormats QueryResultFormats
var resultFormatsByOID QueryResultFormatsByOID
simpleProtocol := c.config.PreferSimpleProtocol
optionLoop:
for len(args) > 0 {
@ -516,14 +559,39 @@ optionLoop:
case QueryResultFormatsByOID:
resultFormatsByOID = arg
args = args[1:]
case QuerySimpleProtocol:
simpleProtocol = bool(arg)
args = args[1:]
default:
break optionLoop
}
}
c.eqb.Reset()
rows := c.getRows(sql, args)
var err error
if simpleProtocol {
sql, err = c.sanitizeForSimpleQuery(sql, args...)
if err != nil {
rows.fatal(err)
return rows, err
}
mrr := c.pgConn.Exec(ctx, sql)
if mrr.NextResult() {
rows.resultReader = mrr.ResultReader()
rows.multiResultReader = mrr
} else {
err = mrr.Close()
rows.fatal(err)
return rows, err
}
return rows, nil
}
c.eqb.Reset()
ps, ok := c.preparedStatements[sql]
if !ok {
psd, err := c.pgConn.Prepare(ctx, "", sql, nil)
@ -550,7 +618,6 @@ optionLoop:
}
rows.sql = ps.SQL
var err error
args, err = convertDriverValuers(args)
if err != nil {
rows.fatal(err)
@ -663,3 +730,24 @@ func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
mrr: mrr,
}
}
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
}
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
}
var err error
valueArgs := make([]interface{}, len(args))
for i, a := range args {
valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a)
if err != nil {
return "", err
}
}
return sanitize.SanitizeSQL(sql, valueArgs...)
}

View File

@ -71,6 +71,31 @@ func TestConnect(t *testing.T) {
}
}
func TestConnectWithPreferSimpleProtocol(t *testing.T) {
t.Parallel()
connConfig := mustParseConfig(t, os.Getenv("PGX_TEST_DATABASE"))
connConfig.PreferSimpleProtocol = true
conn := mustConnect(t, connConfig)
defer closeConn(t, conn)
// If simple protocol is used we should be able to correctly scan the result
// into a pgtype.Text as the integer will have been encoded in text.
var s pgtype.Text
err := conn.QueryRow(context.Background(), "select $1::int4", 42).Scan(&s)
if err != nil {
t.Fatal(err)
}
if s.Get() != "42" {
t.Fatalf(`expected "42", got %v`, s)
}
ensureConnValid(t, conn)
}
func TestExec(t *testing.T) {
t.Parallel()
@ -251,6 +276,37 @@ func TestExecExtendedProtocol(t *testing.T) {
ensureConnValid(t, conn)
}
func TestExecSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);")
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "CREATE TABLE" {
t.Fatalf("Unexpected results from Exec: %v", commandTag)
}
commandTag, err = conn.Exec(ctx,
"insert into foo(name) values($1);",
pgx.QuerySimpleProtocol(true),
"bar'; drop table foo;--",
)
if err != nil {
t.Fatal(err)
}
if string(commandTag) != "INSERT 0 1" {
t.Fatalf("Unexpected results from Exec: %v", commandTag)
}
}
func TestPrepare(t *testing.T) {
t.Parallel()

View File

@ -0,0 +1,237 @@
package sanitize
import (
"bytes"
"encoding/hex"
"strconv"
"strings"
"time"
"unicode/utf8"
errors "golang.org/x/xerrors"
)
// Part is either a string or an int. A string is raw SQL. An int is a
// argument placeholder.
type Part interface{}
type Query struct {
Parts []Part
}
func (q *Query) Sanitize(args ...interface{}) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
case int:
argIdx := part - 1
if argIdx >= len(args) {
return "", errors.Errorf("insufficient arguments")
}
arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
case int64:
str = strconv.FormatInt(arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
case []byte:
str = QuoteBytes(arg)
case string:
str = QuoteString(arg)
case time.Time:
str = arg.Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", errors.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true
default:
return "", errors.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}
for i, used := range argUse {
if !used {
return "", errors.Errorf("unused argument: %d", i)
}
}
return buf.String(), nil
}
func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
}
for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}
query := &Query{Parts: l.parts}
return query, nil
}
func QuoteString(str string) string {
return "'" + strings.Replace(str, "'", "''", -1) + "'"
}
func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
}
type sqlLexer struct {
src string
start int
pos int
stateFn stateFn
parts []Part
}
type stateFn func(*sqlLexer) stateFn
func rawState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case 'e', 'E':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune == '\'' {
l.pos += width
return escapeStringState
}
case '\'':
return singleQuoteState
case '"':
return doubleQuoteState
case '$':
nextRune, _ := utf8.DecodeRuneInString(l.src[l.pos:])
if '0' <= nextRune && nextRune <= '9' {
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos-width])
}
l.start = l.pos
return placeholderState
}
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func singleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
func doubleQuoteState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '"':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '"' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// placeholderState consumes a placeholder value. The $ must have already has
// already been consumed. The first rune must be a digit.
func placeholderState(l *sqlLexer) stateFn {
num := 0
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
if '0' <= r && r <= '9' {
num *= 10
num += int(r - '0')
} else {
l.parts = append(l.parts, num)
l.pos -= width
l.start = l.pos
return rawState
}
}
}
func escapeStringState(l *sqlLexer) stateFn {
for {
r, width := utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
switch r {
case '\\':
_, width = utf8.DecodeRuneInString(l.src[l.pos:])
l.pos += width
case '\'':
nextRune, width := utf8.DecodeRuneInString(l.src[l.pos:])
if nextRune != '\'' {
return rawState
}
l.pos += width
case utf8.RuneError:
if l.pos-l.start > 0 {
l.parts = append(l.parts, l.src[l.start:l.pos])
l.start = l.pos
}
return nil
}
}
}
// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...interface{}) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
return query.Sanitize(args...)
}

View File

@ -0,0 +1,175 @@
package sanitize_test
import (
"testing"
"github.com/jackc/pgx/v4/internal/sanitize"
)
func TestNewQuery(t *testing.T) {
successTests := []struct {
sql string
expected sanitize.Query
}{
{
sql: "select 42",
expected: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
},
{
sql: "select $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
},
{
sql: "select 'quoted $42', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'quoted $42', ", 1}},
},
{
sql: `select "doubled quoted $42", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "doubled quoted $42", `, 1}},
},
{
sql: "select 'foo''bar', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select 'foo''bar', ", 1}},
},
{
sql: `select "foo""bar", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "foo""bar", `, 1}},
},
{
sql: "select '''', $1",
expected: sanitize.Query{Parts: []sanitize.Part{"select '''', ", 1}},
},
{
sql: `select """", $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select """", `, 1}},
},
{
sql: "select $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11",
expected: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2, ", ", 3, ", ", 4, ", ", 5, ", ", 6, ", ", 7, ", ", 8, ", ", 9, ", ", 10, ", ", 11}},
},
{
sql: `select "adsf""$1""adsf", $1, 'foo''$$12bar', $2, '$3'`,
expected: sanitize.Query{Parts: []sanitize.Part{`select "adsf""$1""adsf", `, 1, `, 'foo''$$12bar', `, 2, `, '$3'`}},
},
{
sql: `select E'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select E'escape string\' $42', `, 1}},
},
{
sql: `select e'escape string\' $42', $1`,
expected: sanitize.Query{Parts: []sanitize.Part{`select e'escape string\' $42', `, 1}},
},
}
for i, tt := range successTests {
query, err := sanitize.NewQuery(tt.sql)
if err != nil {
t.Errorf("%d. %v", i, err)
}
if len(query.Parts) == len(tt.expected.Parts) {
for j := range query.Parts {
if query.Parts[j] != tt.expected.Parts[j] {
t.Errorf("%d. expected part %d to be %v but it was %v", i, j, tt.expected.Parts[j], query.Parts[j])
}
}
} else {
t.Errorf("%d. expected query parts to be %v but it was %v", i, tt.expected.Parts, query.Parts)
}
}
}
func TestQuerySanitize(t *testing.T) {
successfulTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select 42"}},
args: []interface{}{},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{int64(42)},
expected: `select 42`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{float64(1.23)},
expected: `select 1.23`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{true},
expected: `select true`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{[]byte{0, 1, 2, 3, 255}},
expected: `select '\x00010203ff'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{nil},
expected: `select null`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foobar"},
expected: `select 'foobar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{"foo'bar"},
expected: `select 'foo''bar'`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{`foo\'bar`},
expected: `select 'foo\''bar'`,
},
}
for i, tt := range successfulTests {
actual, err := tt.query.Sanitize(tt.args...)
if err != nil {
t.Errorf("%d. %v", i, err)
continue
}
if tt.expected != actual {
t.Errorf("%d. expected %s, but got %s", i, tt.expected, actual)
}
}
errorTests := []struct {
query sanitize.Query
args []interface{}
expected string
}{
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1, ", ", 2}},
args: []interface{}{int64(42)},
expected: `insufficient arguments`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select 'foo'"}},
args: []interface{}{int64(42)},
expected: `unused argument: 0`,
},
{
query: sanitize.Query{Parts: []sanitize.Part{"select ", 1}},
args: []interface{}{42},
expected: `invalid arg type: int`,
},
}
for i, tt := range errorTests {
_, err := tt.query.Sanitize(tt.args...)
if err == nil || err.Error() != tt.expected {
t.Errorf("%d. expected error %v, got %v", i, tt.expected, err)
}
}
}

View File

@ -1305,3 +1305,220 @@ func TestRowsFromResultReader(t *testing.T) {
t.Error("Wrong values returned")
}
}
func TestConnSimpleProtocol(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
// Test all supported low-level types
{
expected := int64(42)
var actual int64
err := conn.QueryRow(
context.Background(),
"select $1::int8",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := float64(1.23)
var actual float64
err := conn.QueryRow(
context.Background(),
"select $1::float8",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := true
var actual bool
err := conn.QueryRow(
context.Background(),
"select $1",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := []byte{0, 1, 20, 35, 64, 80, 120, 3, 255, 240, 128, 95}
var actual []byte
err := conn.QueryRow(
context.Background(),
"select $1::bytea",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if bytes.Compare(actual, expected) != 0 {
t.Errorf("expected %v got %v", expected, actual)
}
}
{
expected := "test"
var actual string
err := conn.QueryRow(
context.Background(),
"select $1::text",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test high-level type
{
expected := pgtype.Circle{P: pgtype.Vec2{1, 2}, R: 1.5, Status: pgtype.Present}
actual := expected
err := conn.QueryRow(
context.Background(),
"select $1::circle",
pgx.QuerySimpleProtocol(true),
&expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
// Test multiple args in single query
{
expectedInt64 := int64(234423)
expectedFloat64 := float64(-0.2312)
expectedBool := true
expectedBytes := []byte{255, 0, 23, 16, 87, 45, 9, 23, 45, 223}
expectedString := "test"
var actualInt64 int64
var actualFloat64 float64
var actualBool bool
var actualBytes []byte
var actualString string
err := conn.QueryRow(
context.Background(),
"select $1::int8, $2::float8, $3, $4::bytea, $5::text",
pgx.QuerySimpleProtocol(true),
expectedInt64, expectedFloat64, expectedBool, expectedBytes, expectedString,
).Scan(&actualInt64, &actualFloat64, &actualBool, &actualBytes, &actualString)
if err != nil {
t.Error(err)
}
if expectedInt64 != actualInt64 {
t.Errorf("expected %v got %v", expectedInt64, actualInt64)
}
if expectedFloat64 != actualFloat64 {
t.Errorf("expected %v got %v", expectedFloat64, actualFloat64)
}
if expectedBool != actualBool {
t.Errorf("expected %v got %v", expectedBool, actualBool)
}
if bytes.Compare(expectedBytes, actualBytes) != 0 {
t.Errorf("expected %v got %v", expectedBytes, actualBytes)
}
if expectedString != actualString {
t.Errorf("expected %v got %v", expectedString, actualString)
}
}
// Test dangerous cases
{
expected := "foo';drop table users;"
var actual string
err := conn.QueryRow(
context.Background(),
"select $1",
pgx.QuerySimpleProtocol(true),
expected,
).Scan(&actual)
if err != nil {
t.Error(err)
}
if expected != actual {
t.Errorf("expected %v got %v", expected, actual)
}
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonUTF8ClientEncoding(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "set client_encoding to 'SQL_ASCII'")
var expected string
err := conn.QueryRow(
context.Background(),
"select $1",
pgx.QuerySimpleProtocol(true),
"test",
).Scan(&expected)
if err == nil {
t.Error("expected error when client_encoding not UTF8, but no error occurred")
}
ensureConnValid(t, conn)
}
func TestConnSimpleProtocolRefusesNonStandardConformingStrings(t *testing.T) {
t.Parallel()
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
defer closeConn(t, conn)
mustExec(t, conn, "set standard_conforming_strings to off")
var expected string
err := conn.QueryRow(
context.Background(),
"select $1",
pgx.QuerySimpleProtocol(true),
`\'; drop table users; --`,
).Scan(&expected)
if err == nil {
t.Error("expected error when standard_conforming_strings is off, but no error occurred")
}
ensureConnValid(t, conn)
}

10
rows.go
View File

@ -86,7 +86,8 @@ type connRows struct {
args []interface{}
closed bool
resultReader *pgconn.ResultReader
resultReader *pgconn.ResultReader
multiResultReader *pgconn.MultiResultReader
}
func (rows *connRows) FieldDescriptions() []pgproto3.FieldDescription {
@ -107,6 +108,13 @@ func (rows *connRows) Close() {
}
}
if rows.multiResultReader != nil {
closeErr := rows.multiResultReader.Close()
if rows.err == nil {
rows.err = closeErr
}
}
if rows.logger != nil {
if rows.err == nil {
if rows.logger.shouldLog(LogLevelInfo) {

View File

@ -1054,6 +1054,83 @@ func TestRowsColumnTypes(t *testing.T) {
}
}
func TestSimpleQueryLifeCycle(t *testing.T) {
config, err := pgx.ParseConfig("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")
if err != nil {
t.Fatalf("pgx.ParseConnectionString failed: %v", err)
}
config.PreferSimpleProtocol = true
db := stdlib.OpenDB(*config)
defer closeDB(t, db)
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0)
for rows.Next() {
rowCount++
var (
s string
n int64
)
if err := rows.Scan(&s, &n); err != nil {
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
}
if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s)
}
if n != rowCount {
t.Errorf("Expected %d, received %d", rowCount, n)
}
}
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
rows, err = db.Query("select 1 where false")
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount = int64(0)
for rows.Next() {
rowCount++
}
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 0 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureConnValid(t, db)
}
// https://github.com/jackc/pgx/issues/409
func TestScanJSONIntoJSONRawMessage(t *testing.T) {
db := openDB(t)