mirror of https://github.com/jackc/pgx.git
Add prefer simple protocol support to stdlib
Test code partially taken from james-lawrence
(7471e7f9eb
)
pull/388/head
parent
bd76a96882
commit
3707b79782
|
@ -158,7 +158,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
c := &Conn{conn: conn, driver: d}
|
||||
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
@ -210,9 +210,10 @@ func UnregisterDriverConfig(c *DriverConfig) {
|
|||
}
|
||||
|
||||
type Conn struct {
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
driver *Driver
|
||||
conn *pgx.Conn
|
||||
psCount int64 // Counter used for creating unique prepared statement names
|
||||
driver *Driver
|
||||
connConfig pgx.ConnConfig
|
||||
}
|
||||
|
||||
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
||||
|
@ -303,14 +304,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
|||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.Prepare("", query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPrepared("", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPrepared("", argsV)
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
||||
|
@ -318,14 +329,24 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na
|
|||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||
if !c.connConfig.PreferSimpleProtocol {
|
||||
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
}
|
||||
|
||||
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
restrictBinaryToDatabaseSqlTypes(ps)
|
||||
|
||||
return c.queryPreparedContext(ctx, "", argsV)
|
||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
||||
more := rows.Next()
|
||||
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
|
||||
}
|
||||
|
||||
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
|
||||
|
@ -408,8 +429,10 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri
|
|||
}
|
||||
|
||||
type Rows struct {
|
||||
rows *pgx.Rows
|
||||
values []interface{}
|
||||
rows *pgx.Rows
|
||||
values []interface{}
|
||||
skipNext bool
|
||||
skipNextMore bool
|
||||
}
|
||||
|
||||
func (r *Rows) Columns() []string {
|
||||
|
@ -486,7 +509,14 @@ func (r *Rows) Next(dest []driver.Value) error {
|
|||
}
|
||||
}
|
||||
|
||||
more := r.rows.Next()
|
||||
var more bool
|
||||
if r.skipNext {
|
||||
more = r.skipNextMore
|
||||
r.skipNext = false
|
||||
} else {
|
||||
more = r.rows.Next()
|
||||
}
|
||||
|
||||
if !more {
|
||||
if r.rows.Err() == nil {
|
||||
return io.EOF
|
||||
|
|
|
@ -1385,3 +1385,84 @@ func TestRowsColumnTypes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleQueryLifeCycle(t *testing.T) {
|
||||
driverConfig := stdlib.DriverConfig{
|
||||
ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true},
|
||||
}
|
||||
|
||||
stdlib.RegisterDriverConfig(&driverConfig)
|
||||
defer stdlib.UnregisterDriverConfig(&driverConfig)
|
||||
|
||||
db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
|
||||
if err != nil {
|
||||
t.Fatalf("sql.Open failed: %v", err)
|
||||
}
|
||||
defer closeDB(t, db)
|
||||
|
||||
rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rowCount := int64(0)
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
var (
|
||||
s string
|
||||
n int64
|
||||
)
|
||||
|
||||
if err := rows.Scan(&s, &n); err != nil {
|
||||
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
if s != "foo" {
|
||||
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||
}
|
||||
|
||||
if n != rowCount {
|
||||
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||
}
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 10 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rows, err = db.Query("select 1 where false")
|
||||
if err != nil {
|
||||
t.Fatalf("stmt.Query unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
rowCount = int64(0)
|
||||
|
||||
for rows.Next() {
|
||||
rowCount++
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||
}
|
||||
|
||||
if rowCount != 0 {
|
||||
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||
}
|
||||
|
||||
err = rows.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, db)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue