Add prefer simple protocol support to stdlib

Test code partially taken from james-lawrence
(7471e7f9eb)
pull/388/head
Jack Christensen 2018-01-13 13:40:04 -06:00
parent bd76a96882
commit 3707b79782
2 changed files with 126 additions and 15 deletions

View File

@ -158,7 +158,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
}
}
c := &Conn{conn: conn, driver: d}
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
return c, nil
}
@ -210,9 +210,10 @@ func UnregisterDriverConfig(c *DriverConfig) {
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
connConfig pgx.ConnConfig
}
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
@ -303,14 +304,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
return nil, driver.ErrBadConn
}
ps, err := c.conn.Prepare("", query)
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.Prepare("", query)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
}
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
@ -318,14 +329,24 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
return nil, driver.ErrBadConn
}
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
}
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
@ -408,8 +429,10 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
}
type Rows struct {
rows *pgx.Rows
values []interface{}
rows *pgx.Rows
values []interface{}
skipNext bool
skipNextMore bool
}
func (r *Rows) Columns() []string {
@ -486,7 +509,14 @@ func (r *Rows) Next(dest []driver.Value) error {
}
}
more := r.rows.Next()
var more bool
if r.skipNext {
more = r.skipNextMore
r.skipNext = false
} else {
more = r.rows.Next()
}
if !more {
if r.rows.Err() == nil {
return io.EOF

View File

@ -1385,3 +1385,84 @@ func TestRowsColumnTypes(t *testing.T) {
}
}
}
func TestSimpleQueryLifeCycle(t *testing.T) {
driverConfig := stdlib.DriverConfig{
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
}
stdlib.RegisterDriverConfig(&driverConfig)
defer stdlib.UnregisterDriverConfig(&driverConfig)
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
if err != nil {
t.Fatalf("sql.Open failed: %v", err)
}
defer closeDB(t, db)
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount := int64(0)
for rows.Next() {
rowCount++
var (
s string
n int64
)
if err := rows.Scan(&s, &n); err != nil {
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
}
if s != "foo" {
t.Errorf(`Expected "foo", received "%v"`, s)
}
if n != rowCount {
t.Errorf("Expected %d, received %d", rowCount, n)
}
}
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 10 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
rows, err = db.Query("select 1 where false")
if err != nil {
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
}
rowCount = int64(0)
for rows.Next() {
rowCount++
}
if err = rows.Err(); err != nil {
t.Fatalf("rows.Err unexpectedly is: %v", err)
}
if rowCount != 0 {
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
}
err = rows.Close()
if err != nil {
t.Fatalf("rows.Close unexpectedly failed: %v", err)
}
ensureConnValid(t, db)
}