From 3707b79782aabab12c5eda8b51b692bbc873a1fe Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 13 Jan 2018 13:40:04 -0600 Subject: [PATCH] Add prefer simple protocol support to stdlib Test code partially taken from james-lawrence (7471e7f9eb0f33a05e2f0cf06db8714850880d72) --- stdlib/sql.go | 60 +++++++++++++++++++++++++--------- stdlib/sql_test.go | 81 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 15 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 5dec8830..2d4930ee 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -158,7 +158,7 @@ func (d *Driver) Open(name string) (driver.Conn, error) { } } - c := &Conn{conn: conn, driver: d} + c := &Conn{conn: conn, driver: d, connConfig: connConfig} return c, nil } @@ -210,9 +210,10 @@ func UnregisterDriverConfig(c *DriverConfig) { } type Conn struct { - conn *pgx.Conn - psCount int64 // Counter used for creating unique prepared statement names - driver *Driver + conn *pgx.Conn + psCount int64 // Counter used for creating unique prepared statement names + driver *Driver + connConfig pgx.ConnConfig } func (c *Conn) Prepare(query string) (driver.Stmt, error) { @@ -303,14 +304,24 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return nil, driver.ErrBadConn } - ps, err := c.conn.Prepare("", query) + if !c.connConfig.PreferSimpleProtocol { + ps, err := c.conn.Prepare("", query) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + return c.queryPrepared("", argsV) + } + + rows, err := c.conn.Query(query, valueToInterface(argsV)...) if err != nil { return nil, err } - restrictBinaryToDatabaseSqlTypes(ps) - - return c.queryPrepared("", argsV) + // Preload first row because otherwise we won't know what columns are available when database/sql asks. + more := rows.Next() + return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil } func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { @@ -318,14 +329,24 @@ func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.Na return nil, driver.ErrBadConn } - ps, err := c.conn.PrepareEx(ctx, "", query, nil) + if !c.connConfig.PreferSimpleProtocol { + ps, err := c.conn.PrepareEx(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + return c.queryPreparedContext(ctx, "", argsV) + } + + rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...) if err != nil { return nil, err } - restrictBinaryToDatabaseSqlTypes(ps) - - return c.queryPreparedContext(ctx, "", argsV) + // Preload first row because otherwise we won't know what columns are available when database/sql asks. + more := rows.Next() + return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil } func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { @@ -408,8 +429,10 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri } type Rows struct { - rows *pgx.Rows - values []interface{} + rows *pgx.Rows + values []interface{} + skipNext bool + skipNextMore bool } func (r *Rows) Columns() []string { @@ -486,7 +509,14 @@ func (r *Rows) Next(dest []driver.Value) error { } } - more := r.rows.Next() + var more bool + if r.skipNext { + more = r.skipNextMore + r.skipNext = false + } else { + more = r.rows.Next() + } + if !more { if r.rows.Err() == nil { return io.EOF diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index f2310728..a4a99971 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1385,3 +1385,84 @@ func TestRowsColumnTypes(t *testing.T) { } } } + +func TestSimpleQueryLifeCycle(t *testing.T) { + driverConfig := stdlib.DriverConfig{ + ConnConfig: pgx.ConnConfig{PreferSimpleProtocol: true}, + } + + stdlib.RegisterDriverConfig(&driverConfig) + defer stdlib.UnregisterDriverConfig(&driverConfig) + + db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test")) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + defer closeDB(t, db) + + rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + var ( + s string + n int64 + ) + + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + + if err = rows.Err(); err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + rows, err = db.Query("select 1 where false") + if err != nil { + t.Fatalf("stmt.Query unexpectedly failed: %v", err) + } + + rowCount = int64(0) + + for rows.Next() { + rowCount++ + } + + if err = rows.Err(); err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + + if rowCount != 0 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } + + ensureConnValid(t, db) +}