From 78adfb13d796427deafa89fc45aa5c7e47f8d51b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 14:20:00 -0600 Subject: [PATCH 01/23] Add Ping, PingContext, and ExecContext --- conn.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++------ conn_test.go | 68 +++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index 602ecbff..645b9c5d 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "errors" "fmt" + "golang.org/x/net/context" "io" "net" "net/url" @@ -39,6 +40,22 @@ type ConnConfig struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) } +func (cc *ConnConfig) networkAddress() (network, address string) { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } + + return network, address +} + // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Use ConnPool to manage access to multiple database connections from multiple // goroutines. @@ -194,17 +211,7 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql } } - network := "tcp" - address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(c.config.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = c.config.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - } + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } @@ -1292,3 +1299,70 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } + +// cancelQuery sends a cancel request to the PostgreSQL server. It returns an +// error if unable to deliver the cancel request, but lack of an error does not +// ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See +// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 +func (c *Conn) cancelQuery() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + return err +} + +func (c *Conn) Ping() error { + _, err := c.Exec(";") + return err +} + +func (c *Conn) PingContext(ctx context.Context) error { + _, err := c.ExecContext(ctx, ";") + return err +} + +func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan bool) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + <-doneChan + closedChan <- true + case <-doneChan: + closedChan <- false + } + }() + + commandTag, err = c.Exec(sql, arguments...) + + // Signal cancelation goroutine that operation is done + doneChan <- struct{}{} + + // If c was closed due to context cancelation then return context err + if <-closedChan { + return "", ctx.Err() + } + + return commandTag, err +} diff --git a/conn_test.go b/conn_test.go index 9ed073ce..a9cf02c9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "crypto/tls" "fmt" + "golang.org/x/net/context" "net" "os" "reflect" @@ -816,6 +817,73 @@ func TestExecFailure(t *testing.T) { } } +func TestExecContextWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecContext: %v", commandTag) + } +} + +func TestExecContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecContext(ctx, "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestExecContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + _, err := conn.ExecContext(ctx, "select pg_sleep(60)") + if err != context.Canceled { + t.Fatal("Expected context.Canceled err, got %v", err) + } + + time.Sleep(500 * time.Millisecond) + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } +} + func TestPrepare(t *testing.T) { t.Parallel() From 3e13b333d9d3e2fa14f8e7e43ae041dcd6602433 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 15:40:58 -0600 Subject: [PATCH 02/23] Add QueryContext --- query.go | 48 ++++++++++++++++++++++ query_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 159 insertions(+) diff --git a/query.go b/query.go index 19b867e2..121dcfe3 100644 --- a/query.go +++ b/query.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "golang.org/x/net/context" "time" ) @@ -49,6 +50,9 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool + + ctx context.Context + doneChan chan struct{} } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -120,6 +124,15 @@ func (rows *Rows) Close() { return } rows.readUntilReadyForQuery() + + if rows.ctx != nil { + select { + case <-rows.ctx.Done(): + rows.err = rows.ctx.Err() + case rows.doneChan <- struct{}{}: + } + } + rows.close() } @@ -492,3 +505,38 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { rows, _ := c.Query(sql, args...) return (*Row)(rows) } + +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + case <-doneChan: + } + }() + + rows, err := c.Query(sql, args...) + + if err != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case doneChan <- struct{}{}: + return nil, err + } + } + + rows.ctx = ctx + rows.doneChan = doneChan + + return rows, nil +} diff --git a/query_test.go b/query_test.go index f08887b5..ca05fb42 100644 --- a/query_test.go +++ b/query_test.go @@ -4,6 +4,7 @@ import ( "bytes" "database/sql" "fmt" + "golang.org/x/net/context" "strings" "testing" "time" @@ -1412,3 +1413,113 @@ func TestConnQueryDatabaseSQLNullX(t *testing.T) { ensureConnValid(t, conn) } + +func TestQueryContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 42::integer") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + if rowCount != 1 { + t.Fatalf("Expected 1 row, got %d", rowCount) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextErrorWhileReceivingRows(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + rows, err := conn.QueryContext(ctx, "select 10/(10-n) from generate_series(1, 100) n") + if err != nil { + t.Fatal(err) + } + + var result, rowCount int + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + t.Fatal(err) + } + rowCount++ + } + + if rows.Err() == nil || rows.Err().Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", rows.Err()) + } + + if rowCount != 9 { + t.Fatalf("Expected 9 rows, got %d", rowCount) + } + if result != 10 { + t.Fatalf("Expected result 10, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + rows, err := conn.QueryContext(ctx, "select pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + + for rows.Next() { + t.Fatal("No rows should ever be ready -- context cancel apparently did not happen") + } + + if rows.Err() != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", rows.Err()) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } + +} From 24193ee3223581d6593d5de2364f72839c73b5ba Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 15:57:06 -0600 Subject: [PATCH 03/23] Add QueryRowContext --- query.go | 15 ++++++------ query_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 9 deletions(-) diff --git a/query.go b/query.go index 121dcfe3..fc3f405b 100644 --- a/query.go +++ b/query.go @@ -507,12 +507,6 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - doneChan := make(chan struct{}) go func() { @@ -529,9 +523,9 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { case <-ctx.Done(): - return nil, ctx.Err() + return rows, ctx.Err() case doneChan <- struct{}{}: - return nil, err + return rows, err } } @@ -540,3 +534,8 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} return rows, nil } + +func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := c.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} diff --git a/query_test.go b/query_test.go index ca05fb42..6909ba1e 100644 --- a/query_test.go +++ b/query_test.go @@ -1521,5 +1521,71 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { if err != pgx.ErrNoRows { t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") } - +} + +func TestQueryRowContextSuccess(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 42::integer").Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != 42 { + t.Fatalf("Expected result 42, got %d", result) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextErrorWhileReceivingRow(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + var result int + err := conn.QueryRowContext(ctx, "select 10/0").Scan(&result) + if err == nil || err.Error() != "ERROR: division by zero (SQLSTATE 22012)" { + t.Fatalf("Expected division by zero error, but got %v", err) + } + + ensureConnValid(t, conn) +} + +func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + var result []byte + err := conn.QueryRowContext(ctx, "select pg_sleep(5)").Scan(&result) + if err != context.Canceled { + t.Fatal("Expected context.Canceled error, got %v", err) + } + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } } From a9e7e3acbc04145211116a11959c4db176a5df9a Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 16:03:20 -0600 Subject: [PATCH 04/23] Extract connection dead on server test --- conn_test.go | 11 +---------- helper_test.go | 17 ++++++++++++++++- query_test.go | 18 ++---------------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/conn_test.go b/conn_test.go index a9cf02c9..e92c7ca3 100644 --- a/conn_test.go +++ b/conn_test.go @@ -872,16 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled err, got %v", err) } - time.Sleep(500 * time.Millisecond) - - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } func TestPrepare(t *testing.T) { diff --git a/helper_test.go b/helper_test.go index eff731e8..997ae26f 100644 --- a/helper_test.go +++ b/helper_test.go @@ -21,7 +21,6 @@ func mustReplicationConnect(t testing.TB, config pgx.ConnConfig) *pgx.Replicatio return conn } - func closeConn(t testing.TB, conn *pgx.Conn) { err := conn.Close() if err != nil { @@ -72,3 +71,19 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } + +func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) { + checkConn := mustConnect(t, config) + defer closeConn(t, checkConn) + + for i := 0; i < 10; i++ { + var found bool + err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err == pgx.ErrNoRows { + return + } else if err != nil { + t.Fatalf("Unable to check if conn is dead on server: %v", err) + } + } + t.Fatal("Expected conn to be disconnected from server, but it wasn't") +} diff --git a/query_test.go b/query_test.go index 6909ba1e..40886f2e 100644 --- a/query_test.go +++ b/query_test.go @@ -1513,14 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", rows.Err()) } - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } func TestQueryRowContextSuccess(t *testing.T) { @@ -1580,12 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", err) } - checkConn := mustConnect(t, *defaultConnConfig) - defer closeConn(t, checkConn) - - var found bool - err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err != pgx.ErrNoRows { - t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") - } + ensureConnDeadOnServer(t, conn, *defaultConnConfig) } From 94eea5128e3eb1f37f9b70771b0d9a68545839b5 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 18:09:25 -0600 Subject: [PATCH 05/23] Add context dependency to travis --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index d9ea43b0..4a3b91e2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -51,6 +51,7 @@ install: - go get -u github.com/shopspring/decimal - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake + - go get -u golang.org/x/net/context script: - go test -v -race -short ./... From 37b86083e4361243246805ceae845e36c9692e9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 18:44:55 -0600 Subject: [PATCH 06/23] Fix race condition with canceled contexts --- conn.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/conn.go b/conn.go index 645b9c5d..45bb9441 100644 --- a/conn.go +++ b/conn.go @@ -18,6 +18,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" ) @@ -74,8 +75,6 @@ type Conn struct { preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification - alive bool - causeOfDeath error logger Logger logLevel int mr msgReader @@ -85,6 +84,10 @@ type Conn struct { busy bool poolResetCount int preallocatedRows []Rows + + closingLock sync.Mutex + alive bool + causeOfDeath error } // PreparedStatement is a description of a prepared statement @@ -391,14 +394,14 @@ func (c *Conn) loadInetConstants() error { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - if !c.IsAlive() { + c.closingLock.Lock() + defer c.closingLock.Unlock() + + if !c.alive { return nil } - wbuf := newWriteBuf(c, 'X') - wbuf.closeMsg() - - _, err = c.conn.Write(wbuf.buf) + _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) c.die(errors.New("Closed")) if c.shouldLog(LogLevelInfo) { @@ -870,7 +873,10 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - return c.alive + c.closingLock.Lock() + alive := c.alive + c.closingLock.Unlock() + return alive } func (c *Conn) CauseOfDeath() error { From 14eedb4fcaa7eec18725aeb692346a1d2e883b30 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 21:10:13 -0600 Subject: [PATCH 07/23] Add ConnPool context methods --- conn.go | 3 +++ conn_pool.go | 34 ++++++++++++++++++++++++++++++++++ context-todo.txt | 12 ++++++++++++ query.go | 12 ++++++++---- stress_test.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 context-todo.txt diff --git a/conn.go b/conn.go index 45bb9441..f7c06014 100644 --- a/conn.go +++ b/conn.go @@ -1051,9 +1051,12 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { + c.closingLock.Lock() if !c.alive { + c.closingLock.Unlock() return 0, nil, ErrDeadConn } + c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { diff --git a/conn_pool.go b/conn_pool.go index 6d04565d..50b9d588 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -2,6 +2,7 @@ package pgx import ( "errors" + "golang.org/x/net/context" "sync" "time" ) @@ -357,6 +358,16 @@ func (p *ConnPool) Exec(sql string, arguments ...interface{}) (commandTag Comman return c.Exec(sql, arguments...) } +func (p *ConnPool) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + var c *Conn + if c, err = p.Acquire(); err != nil { + return + } + defer p.Release(c) + + return c.ExecContext(ctx, sql, arguments...) +} + // Query acquires a connection and delegates the call to that connection. When // *Rows are closed, the connection is released automatically. func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { @@ -377,6 +388,24 @@ func (p *ConnPool) Query(sql string, args ...interface{}) (*Rows, error) { return rows, nil } +func (p *ConnPool) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { + c, err := p.Acquire() + if err != nil { + // Because checking for errors can be deferred to the *Rows, build one with the error + return &Rows{closed: true, err: err}, err + } + + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + p.Release(c) + return rows, err + } + + rows.AfterClose(p.rowsAfterClose) + + return rows, nil +} + // QueryRow acquires a connection and delegates the call to that connection. The // connection is released automatically after Scan is called on the returned // *Row. @@ -385,6 +414,11 @@ func (p *ConnPool) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } +func (p *ConnPool) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { + rows, _ := p.QueryContext(ctx, sql, args...) + return (*Row)(rows) +} + // Begin acquires a connection and begins a transaction on it. When the // transaction is closed the connection will be automatically released. func (p *ConnPool) Begin() (*Tx, error) { diff --git a/context-todo.txt b/context-todo.txt new file mode 100644 index 00000000..b5a20d0a --- /dev/null +++ b/context-todo.txt @@ -0,0 +1,12 @@ +Add more testing +- stress test style +- pgmock + +Add documentation + +Add PrepareContext +Add context methods to ConnPool +Add context methods to Tx +Add context support database/sql + +Benchmark - possibly cache done channel on Conn diff --git a/query.go b/query.go index fc3f405b..3ded881d 100644 --- a/query.go +++ b/query.go @@ -51,8 +51,9 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} + ctx context.Context + doneChan chan struct{} + closedChan chan bool } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -127,7 +128,7 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.ctx.Done(): + case <-rows.closedChan: rows.err = rows.ctx.Err() case rows.doneChan <- struct{}{}: } @@ -508,12 +509,14 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { doneChan := make(chan struct{}) + closedChan := make(chan bool) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() + closedChan <- true case <-doneChan: } }() @@ -522,7 +525,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} if err != nil { select { - case <-ctx.Done(): + case <-closedChan: return rows, ctx.Err() case doneChan <- struct{}{}: return rows, err @@ -531,6 +534,7 @@ func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{} rows.ctx = ctx rows.doneChan = doneChan + rows.closedChan = closedChan return rows, nil } diff --git a/stress_test.go b/stress_test.go index 150d13c8..d22d9d6b 100644 --- a/stress_test.go +++ b/stress_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "errors" "fmt" + "golang.org/x/net/context" "math/rand" "testing" "time" @@ -44,6 +45,8 @@ func TestStressConnPool(t *testing.T) { {"listenAndPoolUnlistens", listenAndPoolUnlistens}, {"reset", func(p *pgx.ConnPool, n int) error { p.Reset(); return nil }}, {"poolPrepareUseAndDeallocate", poolPrepareUseAndDeallocate}, + {"canceledQueryContext", canceledQueryContext}, + {"canceledExecContext", canceledExecContext}, } var timer *time.Timer @@ -344,3 +347,43 @@ func txMultipleQueries(pool *pgx.ConnPool, actionNum int) error { return tx.Commit() } + +func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + if err == context.Canceled { + return nil + } else if err != nil { + return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + } + + for rows.Next() { + return errors.New("canceledQueryContext: should never receive row") + } + + if rows.Err() != context.Canceled { + return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + } + + return nil +} + +func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(time.Duration(rand.Intn(50)) * time.Millisecond) + cancelFunc() + }() + + _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + if err != context.Canceled { + return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + } + + return nil +} From 351eb8ba679c66de3a67db7da9e0cd06f6fecda8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 6 Feb 2017 19:39:34 -0600 Subject: [PATCH 08/23] Initial proof-of-concept database/sql context support --- conn.go | 52 ++++++++++++++++++++++++++++++++++++++++----------- stdlib/sql.go | 46 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index f7c06014..b8131716 100644 --- a/conn.go +++ b/conn.go @@ -619,6 +619,41 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + return c.PrepareExContext(context.Background(), name, sql, opts) + +} + +func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan struct{}) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + closedChan <- struct{}{} + case <-doneChan: + } + }() + + ps, err = c.prepareEx(name, sql, opts) + + select { + case <-closedChan: + return nil, ctx.Err() + case doneChan <- struct{}{}: + return ps, err + } +} + +func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil @@ -1349,29 +1384,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa } doneChan := make(chan struct{}) - closedChan := make(chan bool) + closedChan := make(chan struct{}) go func() { select { case <-ctx.Done(): c.cancelQuery() c.Close() - <-doneChan - closedChan <- true + closedChan <- struct{}{} case <-doneChan: - closedChan <- false } }() commandTag, err = c.Exec(sql, arguments...) - // Signal cancelation goroutine that operation is done - doneChan <- struct{}{} - - // If c was closed due to context cancelation then return context err - if <-closedChan { + select { + case <-closedChan: return "", ctx.Err() + case doneChan <- struct{}{}: + return commandTag, err } - - return commandTag, err } diff --git a/stdlib/sql.go b/stdlib/sql.go index 610aefd4..74218a7b 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -44,6 +44,7 @@ package stdlib import ( + "context" "database/sql" "database/sql/driver" "errors" @@ -211,6 +212,21 @@ func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { return c.queryPrepared("", argsV) } +func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + ps, err := c.conn.PrepareExContext(ctx, "", query, nil) + if err != nil { + return nil, err + } + + restrictBinaryToDatabaseSqlTypes(ps) + + return c.queryPreparedContext(ctx, "", argsV) +} + func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn @@ -226,6 +242,24 @@ func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, er return &Rows{rows: rows}, nil } +func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + rows, err := c.conn.QueryContext(ctx, name, args...) + if err != nil { + fmt.Println(err) + return nil, err + } + + fmt.Println("ere") + + return &Rows{rows: rows}, nil +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) @@ -318,6 +352,18 @@ func valueToInterface(argsV []driver.Value) []interface{} { return args } +func namedValueToInterface(argsV []driver.NamedValue) []interface{} { + args := make([]interface{}, 0, len(argsV)) + for _, v := range argsV { + if v.Value != nil { + args = append(args, v.Value.(interface{})) + } else { + args = append(args, nil) + } + } + return args +} + type Tx struct { conn *pgx.Conn } From 004c18e5a21c7837cb6dc578f22471115b29fdc8 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 20:35:37 -0600 Subject: [PATCH 09/23] Begin extracting context handling --- conn.go | 53 +++++++++++++++++++++++------------------------------ query.go | 27 ++++++--------------------- 2 files changed, 29 insertions(+), 51 deletions(-) diff --git a/conn.go b/conn.go index b8131716..453f1a51 100644 --- a/conn.go +++ b/conn.go @@ -88,6 +88,10 @@ type Conn struct { closingLock sync.Mutex alive bool causeOfDeath error + + // context support + doneChan chan struct{} + closedChan chan struct{} } // PreparedStatement is a description of a prepared statement @@ -257,6 +261,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) c.alive = true c.lastActivityTime = time.Now() + c.doneChan = make(chan struct{}) + c.closedChan = make(chan struct{}) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -619,8 +625,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.PrepareExContext(context.Background(), name, sql, opts) - + return c.prepareEx(name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -630,25 +635,14 @@ func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *Pre default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) ps, err = c.prepareEx(name, sql, opts) select { - case <-closedChan: + case <-c.closedChan: return nil, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return ps, err } } @@ -1383,25 +1377,24 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa default: } - doneChan := make(chan struct{}) - closedChan := make(chan struct{}) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- struct{}{} - case <-doneChan: - } - }() + go c.contextHandler(ctx) commandTag, err = c.Exec(sql, arguments...) select { - case <-closedChan: + case <-c.closedChan: return "", ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return commandTag, err } } + +func (c *Conn) contextHandler(ctx context.Context) { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + c.closedChan <- struct{}{} + case <-c.doneChan: + } +} diff --git a/query.go b/query.go index 3ded881d..daf1b354 100644 --- a/query.go +++ b/query.go @@ -51,9 +51,7 @@ type Rows struct { unlockConn bool closed bool - ctx context.Context - doneChan chan struct{} - closedChan chan bool + ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -128,9 +126,9 @@ func (rows *Rows) Close() { if rows.ctx != nil { select { - case <-rows.closedChan: + case <-rows.conn.closedChan: rows.err = rows.ctx.Err() - case rows.doneChan <- struct{}{}: + case rows.conn.doneChan <- struct{}{}: } } @@ -508,33 +506,20 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - doneChan := make(chan struct{}) - closedChan := make(chan bool) - - go func() { - select { - case <-ctx.Done(): - c.cancelQuery() - c.Close() - closedChan <- true - case <-doneChan: - } - }() + go c.contextHandler(ctx) rows, err := c.Query(sql, args...) if err != nil { select { - case <-closedChan: + case <-c.closedChan: return rows, ctx.Err() - case doneChan <- struct{}{}: + case c.doneChan <- struct{}{}: return rows, err } } rows.ctx = ctx - rows.doneChan = doneChan - rows.closedChan = closedChan return rows, nil } From 72b6d32e2f841e6be96c5602c248b2875d345c3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 21:49:58 -0600 Subject: [PATCH 10/23] Extracted more context handling --- conn.go | 71 ++++++++++++++++++++++++++++++++++++---------------- conn_pool.go | 4 +++ query.go | 33 ++++++++---------------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/conn.go b/conn.go index 453f1a51..b662ba4c 100644 --- a/conn.go +++ b/conn.go @@ -90,8 +90,9 @@ type Conn struct { causeOfDeath error // context support - doneChan chan struct{} - closedChan chan struct{} + ctxInProgress bool + doneChan chan struct{} + closedChan chan error } // PreparedStatement is a description of a prepared statement @@ -262,7 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.alive = true c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) - c.closedChan = make(chan struct{}) + c.closedChan = make(chan error) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -629,22 +630,14 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + err = c.initContext(ctx) + if err != nil { + return nil, err } - go c.contextHandler(ctx) - ps, err = c.prepareEx(name, sql, opts) - - select { - case <-c.closedChan: - return nil, ctx.Err() - case c.doneChan <- struct{}{}: - return ps, err - } + err = c.termContext(err) + return ps, err } func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -1371,22 +1364,56 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + + commandTag, err = c.Exec(sql, arguments...) + err = c.termContext(err) + return commandTag, err +} + +func (c *Conn) initContext(ctx context.Context) error { + if c.ctxInProgress { + return errors.New("ctx already in progress") + } + + if ctx.Done() == nil { + return nil + } + select { case <-ctx.Done(): - return "", ctx.Err() + return ctx.Err() default: } + c.ctxInProgress = true + go c.contextHandler(ctx) - commandTag, err = c.Exec(sql, arguments...) + return nil +} + +func (c *Conn) termContext(opErr error) error { + if !c.ctxInProgress { + return opErr + } + + var err error select { - case <-c.closedChan: - return "", ctx.Err() + case err = <-c.closedChan: + if opErr == nil { + err = nil + } case c.doneChan <- struct{}{}: - return commandTag, err + err = opErr } + + c.ctxInProgress = false + return err } func (c *Conn) contextHandler(ctx context.Context) { @@ -1394,7 +1421,7 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-ctx.Done(): c.cancelQuery() c.Close() - c.closedChan <- struct{}{} + c.closedChan <- ctx.Err() case <-c.doneChan: } } diff --git a/conn_pool.go b/conn_pool.go index 50b9d588..2a243a76 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -182,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. func (p *ConnPool) Release(conn *Conn) { + if conn.ctxInProgress { + panic("should never release when context is in progress") + } + if conn.TxStatus != 'I' { conn.Exec("rollback") } diff --git a/query.go b/query.go index daf1b354..61136092 100644 --- a/query.go +++ b/query.go @@ -50,8 +50,6 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool - - ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -84,6 +82,9 @@ func (rows *Rows) close() { } } +// TODO - consider inlining in Close(). This method calling rows.close is a +// foot-gun waiting to happen if anyone puts anything between the call to this +// and rows.close. func (rows *Rows) readUntilReadyForQuery() { for { t, r, err := rows.conn.rxMsg() @@ -122,16 +123,8 @@ func (rows *Rows) Close() { if rows.closed { return } + rows.err = rows.conn.termContext(rows.err) rows.readUntilReadyForQuery() - - if rows.ctx != nil { - select { - case <-rows.conn.closedChan: - rows.err = rows.ctx.Err() - case rows.conn.doneChan <- struct{}{}: - } - } - rows.close() } @@ -506,20 +499,16 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - go c.contextHandler(ctx) - - rows, err := c.Query(sql, args...) - + err := c.initContext(ctx) if err != nil { - select { - case <-c.closedChan: - return rows, ctx.Err() - case c.doneChan <- struct{}{}: - return rows, err - } + return nil, err } - rows.ctx = ctx + rows, err := c.Query(sql, args...) + if err != nil { + err = c.termContext(err) + return nil, err + } return rows, nil } From b8fdc38fa861830ab82c6325a019af83e9270913 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 19:37:23 -0600 Subject: [PATCH 11/23] Only store Conn's *bufio.Reader in msgReader Confusing and redundant to have the same *bufio.Reader in msgReader and Conn. --- conn.go | 10 ++++------ replication.go | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/conn.go b/conn.go index b662ba4c..7ecd18b2 100644 --- a/conn.go +++ b/conn.go @@ -61,9 +61,8 @@ func (cc *ConnConfig) networkAddress() (network, address string) { // Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection - lastActivityTime time.Time // the last time the connection was used - reader *bufio.Reader // buffered reader to improve read performance + conn net.Conn // the underlying TCP or unix domain socket connection + lastActivityTime time.Time // the last time the connection was used wbuf [1024]byte writeBuf WriteBuf Pid int32 // backend pid @@ -274,8 +273,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.reader = bufio.NewReader(c.conn) - c.mr.reader = c.reader + c.mr.reader = bufio.NewReader(c.conn) msg := newStartupMessage() @@ -862,7 +860,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = c.reader.Peek(1) + _, err = c.mr.reader.Peek(1) if err != nil { c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { diff --git a/replication.go b/replication.go index 7b28d6b6..12a5c914 100644 --- a/replication.go +++ b/replication.go @@ -289,7 +289,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(timeout time.Duration) (r * } // Wait until there is a byte available before continuing onto the normal msg reading path - _, err = rc.c.reader.Peek(1) + _, err = rc.c.mr.reader.Peek(1) if err != nil { rc.c.conn.SetReadDeadline(zeroTime) // we can only return one error and we already have one -- so ignore possiple error from SetReadDeadline if err, ok := err.(*net.OpError); ok && err.Timeout() { From 855276e2cf09ce6e53ee0c8876422b7975bf0667 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 19:40:01 -0600 Subject: [PATCH 12/23] Remove unused msgReader.Err() --- msg_reader.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/msg_reader.go b/msg_reader.go index 21db5d26..f7b497f7 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -16,11 +16,6 @@ type msgReader struct { shouldLog func(lvl int) bool } -// Err returns any error that the msgReader has experienced -func (r *msgReader) Err() error { - return r.err -} - // fatal tells rc that a Fatal error has occurred func (r *msgReader) fatal(err error) { if r.shouldLog(LogLevelTrace) { From 50b0bea9e57b9c6181b4318bf3f7a89b03cb6ea9 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 21:04:16 -0600 Subject: [PATCH 13/23] msgReader pre-buffers messages when possible --- msg_reader.go | 26 ++++++- msg_reader_test.go | 189 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 212 insertions(+), 3 deletions(-) create mode 100644 msg_reader_test.go diff --git a/msg_reader.go b/msg_reader.go index f7b497f7..1f4e67e9 100644 --- a/msg_reader.go +++ b/msg_reader.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "io" + "net" ) // msgReader is a helper that reads values from a PostgreSQL message. @@ -35,20 +36,39 @@ func (r *msgReader) rxMsg() (byte, error) { r.log(LogLevelTrace, "msgReader.rxMsg discarding unread previous message", "msgBytesRemaining", r.msgBytesRemaining) } - _, err := r.reader.Discard(int(r.msgBytesRemaining)) + n, err := r.reader.Discard(int(r.msgBytesRemaining)) + r.msgBytesRemaining -= int32(n) if err != nil { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } } b, err := r.reader.Peek(5) if err != nil { - r.fatal(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } return 0, err } + msgType := b[0] - r.msgBytesRemaining = int32(binary.BigEndian.Uint32(b[1:])) - 4 + payloadSize := int32(binary.BigEndian.Uint32(b[1:])) - 4 + + // Try to preload bufio.Reader with entire message + b, err = r.reader.Peek(5 + int(payloadSize)) + if err != nil && err != bufio.ErrBufferFull { + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + r.fatal(err) + } + return 0, err + } + + r.msgBytesRemaining = payloadSize r.reader.Discard(5) + return msgType, nil } diff --git a/msg_reader_test.go b/msg_reader_test.go new file mode 100644 index 00000000..2bbd53c9 --- /dev/null +++ b/msg_reader_test.go @@ -0,0 +1,189 @@ +package pgx + +import ( + "bufio" + "net" + "testing" + "time" + + "github.com/jackc/pgmock/pgmsg" +) + +func TestMsgReaderPrebuffersWhenPossible(t *testing.T) { + t.Parallel() + + tests := []struct { + msgType byte + payloadSize int32 + buffered bool + }{ + {1, 50, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 24000, false}, + {9, 4000, true}, + {1, 1500, true}, + {2, 0, true}, + {3, 500, true}, + {4, 1050, true}, + {5, 1500, true}, + {6, 1500, true}, + {7, 4000, true}, + {8, 14000, false}, + {9, 0, true}, + {1, 500, true}, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for _, tt := range tests { + _, err = conn.Write([]byte{tt.msgType}) + if err != nil { + t.Fatal(err) + } + + _, err = conn.Write(bigEndian.Int32(tt.payloadSize + 4)) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, int(tt.payloadSize)) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + for i, tt := range tests { + msgType, err := mr.rxMsg() + if err != nil { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + + if msgType != tt.msgType { + t.Fatalf("%d. Expected %v, got %v", 1, i, tt.msgType, msgType) + } + + if mr.reader.Buffered() < int(tt.payloadSize) && tt.buffered { + t.Fatalf("%d. Expected message to be buffered with at least %d bytes, but only %v bytes buffered", i, tt.payloadSize, mr.reader.Buffered()) + } + } +} + +func TestMsgReaderDeadlineNeverInterruptsNormalSizedMessages(t *testing.T) { + t.Parallel() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + testCount := 10000 + + go func() { + var bigEndian pgmsg.BigEndianBuf + + conn, err := ln.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + for i := 0; i < testCount; i++ { + msgType := byte(i) + + _, err = conn.Write([]byte{msgType}) + if err != nil { + t.Fatal(err) + } + + msgSize := i % 4000 + + _, err = conn.Write(bigEndian.Int32(int32(msgSize + 4))) + if err != nil { + t.Fatal(err) + } + + payload := make([]byte, msgSize) + _, err = conn.Write(payload) + if err != nil { + t.Fatal(err) + } + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + mr := &msgReader{ + reader: bufio.NewReader(conn), + shouldLog: func(int) bool { return false }, + } + + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + + i := 0 + for { + msgType, err := mr.rxMsg() + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + conn.SetReadDeadline(time.Now().Add(time.Millisecond)) + continue + } else { + t.Fatalf("%d. Unexpected error: %v", i, err) + } + } + + expectedMsgType := byte(i) + if msgType != expectedMsgType { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgType, msgType) + } + + expectedMsgSize := i % 4000 + payload := mr.readBytes(mr.msgBytesRemaining) + if mr.err != nil { + t.Fatalf("%d. readBytes killed msgReader: %v", i, mr.err) + } + if len(payload) != expectedMsgSize { + t.Fatalf("%d. Expected %v, got %v", i, expectedMsgSize, len(payload)) + } + + i++ + if i == testCount { + break + } + } +} From 09d37880bafc78b43a429610d8825b095e9f24df Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Thu, 9 Feb 2017 21:42:58 -0600 Subject: [PATCH 14/23] wip --- conn-lock-todo.txt | 11 +++++++++++ conn.go | 7 ++++++- 2 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 conn-lock-todo.txt diff --git a/conn-lock-todo.txt b/conn-lock-todo.txt new file mode 100644 index 00000000..ab5eac95 --- /dev/null +++ b/conn-lock-todo.txt @@ -0,0 +1,11 @@ +Extract all locking state into a separate struct that will encapsulate locking and state change behavior. + +This struct should add or subsume at least the following: +* alive +* closingLock +* ctxInProgress (though this may be restructured because it's possible a Tx may have a ctx and a query run in that Tx could have one) +* busy +* lock/unlock +* Tx in-progress +* Rows in-progress +* ConnPool checked-out or checked-in - maybe include reference to conn pool diff --git a/conn.go b/conn.go index 7ecd18b2..78bdcedc 100644 --- a/conn.go +++ b/conn.go @@ -1403,6 +1403,9 @@ func (c *Conn) termContext(opErr error) error { select { case err = <-c.closedChan: + if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil { + c.Close() // Close connection if unable to disable deadline + } if opErr == nil { err = nil } @@ -1418,7 +1421,9 @@ func (c *Conn) contextHandler(ctx context.Context) { select { case <-ctx.Done(): c.cancelQuery() - c.Close() + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + } c.closedChan <- ctx.Err() case <-c.doneChan: } From f0dfe4fe8926487e5772dade1decef121a7279ea Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 13:01:51 -0600 Subject: [PATCH 15/23] Merge alive and busy states into atomic status --- conn.go | 56 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/conn.go b/conn.go index 78bdcedc..7243a4d1 100644 --- a/conn.go +++ b/conn.go @@ -18,10 +18,17 @@ import ( "regexp" "strconv" "strings" - "sync" + "sync/atomic" "time" ) +const ( + connStatusUninitialized = iota + connStatusClosed + connStatusIdle + connStatusBusy +) + // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) @@ -80,12 +87,10 @@ type Conn struct { fp *fastpath pgsqlAfInet *byte pgsqlAfInet6 *byte - busy bool poolResetCount int preallocatedRows []Rows - closingLock sync.Mutex - alive bool + status int32 // One of connStatus* constants causeOfDeath error // context support @@ -252,14 +257,14 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl defer func() { if c != nil && err != nil { c.conn.Close() - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) } }() c.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) - c.alive = true + atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -399,11 +404,14 @@ func (c *Conn) loadInetConstants() error { // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close() (err error) { - c.closingLock.Lock() - defer c.closingLock.Unlock() - - if !c.alive { - return nil + for { + status := atomic.LoadInt32(&c.status) + if status < connStatusIdle { + return nil + } + if atomic.CompareAndSwapInt32(&c.status, status, connStatusClosed) { + break + } } _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) @@ -893,10 +901,7 @@ func (c *Conn) waitForNotification(deadline time.Time) (*Notification, error) { } func (c *Conn) IsAlive() bool { - c.closingLock.Lock() - alive := c.alive - c.closingLock.Unlock() - return alive + return atomic.LoadInt32(&c.status) >= connStatusIdle } func (c *Conn) CauseOfDeath() error { @@ -1071,12 +1076,9 @@ func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { - c.closingLock.Lock() - if !c.alive { - c.closingLock.Unlock() + if atomic.LoadInt32(&c.status) < connStatusIdle { return 0, nil, ErrDeadConn } - c.closingLock.Unlock() t, err = c.mr.rxMsg() if err != nil { @@ -1261,25 +1263,23 @@ func (c *Conn) txPasswordMessage(password string) (err error) { } func (c *Conn) die(err error) { - c.alive = false + atomic.StoreInt32(&c.status, connStatusClosed) c.causeOfDeath = err c.conn.Close() } func (c *Conn) lock() error { - if c.busy { - return ErrConnBusy + if atomic.CompareAndSwapInt32(&c.status, connStatusIdle, connStatusBusy) { + return nil } - c.busy = true - return nil + return ErrConnBusy } func (c *Conn) unlock() error { - if !c.busy { - return errors.New("unlock conn that is not busy") + if atomic.CompareAndSwapInt32(&c.status, connStatusBusy, connStatusIdle) { + return nil } - c.busy = false - return nil + return errors.New("unlock conn that is not busy") } func (c *Conn) shouldLog(lvl int) bool { From e4f9108e8251f3a6e35c3bd698ad39273b172e9d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 14:59:16 -0600 Subject: [PATCH 16/23] wip --- conn.go | 87 +++++++++++++++++++++++++++++++++++++++++--------- conn_test.go | 2 +- copy_to.go | 1 - fastpath.go | 4 +++ helper_test.go | 16 ---------- query.go | 44 ++----------------------- query_test.go | 4 +-- 7 files changed, 82 insertions(+), 76 deletions(-) diff --git a/conn.go b/conn.go index 7243a4d1..f7443719 100644 --- a/conn.go +++ b/conn.go @@ -93,6 +93,8 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error + readyForQuery bool // can the connection be used to send a query + // context support ctxInProgress bool doneChan chan struct{} @@ -653,6 +655,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + if c.shouldLog(LogLevelError) { defer func() { if err != nil { @@ -692,6 +698,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared c.die(err) return nil, err } + c.readyForQuery = false ps = &PreparedStatement{Name: name, SQL: sql} @@ -706,7 +713,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } switch t { - case parseComplete: case parameterDescription: ps.ParameterOids = c.rxParameterDescription(r) @@ -720,7 +726,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].DataTypeName = t.Name ps.FieldDescriptions[i].FormatCode = t.DefaultFormat } - case noData: case readyForQuery: c.rxReadyForQuery(r) @@ -739,6 +744,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared // Deallocate released a prepared statement func (c *Conn) Deallocate(name string) (err error) { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + delete(c.preparedStatements, name) // close @@ -809,6 +818,10 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + stopTime := time.Now().Add(timeout) for { @@ -916,6 +929,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { } func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } if len(args) == 0 { wbuf := newWriteBuf(c, 'Q') @@ -927,6 +943,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } + c.readyForQuery = false return nil } @@ -944,6 +961,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + // bind wbuf := newWriteBuf(c, 'B') wbuf.WriteByte(0) @@ -991,6 +1012,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} if err != nil { c.die(err) } + c.readyForQuery = false return err } @@ -1040,9 +1062,6 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag case readyForQuery: c.rxReadyForQuery(r) return commandTag, softErr - case rowDescription: - case dataRow: - case bindComplete: case commandComplete: commandTag = CommandTag(r.readCString()) default: @@ -1054,25 +1073,36 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag } // Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages -// is the same regardless of when they occur. +// authentication or query response. The response to these messages is the same +// regardless of when they occur. It also ignores messages that are only +// meaningful in a given context. These messages can occur do to a context +// deadline interrupting message processing. For example, an interrupted query +// may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { switch t { - case 'S': - c.rxParameterStatus(r) - return nil + case bindComplete: + case commandComplete: + case dataRow: + case emptyQueryResponse: case errorResponse: return c.rxErrorResponse(r) + case noData: case noticeResponse: - return nil - case emptyQueryResponse: - return nil case notificationResponse: c.rxNotificationResponse(r) - return nil + case parameterDescription: + case parseComplete: + case readyForQuery: + c.rxReadyForQuery(r) + case rowDescription: + case 'S': + c.rxParameterStatus(r) + default: return fmt.Errorf("Received unknown message type: %c", t) } + + return nil } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { @@ -1082,7 +1112,9 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { t, err = c.mr.rxMsg() if err != nil { - c.die(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + c.die(err) + } } c.lastActivityTime = time.Now() @@ -1183,6 +1215,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) { } func (c *Conn) rxReadyForQuery(r *msgReader) { + c.readyForQuery = true c.TxStatus = r.readByte() } @@ -1428,3 +1461,27 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-c.doneChan: } } + +func (c *Conn) ensureConnectionReadyForQuery() error { + for !c.readyForQuery { + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case errorResponse: + pgErr := c.rxErrorResponse(r) + if pgErr.Severity == "FATAL" { + return pgErr + } + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index e92c7ca3..ca39b4b4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -872,7 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled err, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestPrepare(t *testing.T) { diff --git a/copy_to.go b/copy_to.go index 91292bb0..dd70ada3 100644 --- a/copy_to.go +++ b/copy_to.go @@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() { ct.conn.rxReadyForQuery(r) close(ct.readerErrChan) return - case commandComplete: case errorResponse: ct.readerErrChan <- ct.conn.rxErrorResponse(r) default: diff --git a/fastpath.go b/fastpath.go index 19b98784..30a9f102 100644 --- a/fastpath.go +++ b/fastpath.go @@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg { } func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { + if err := f.cn.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + wbuf := newWriteBuf(f.cn, 'F') // function call wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt16(1) // # of argument format codes diff --git a/helper_test.go b/helper_test.go index 997ae26f..21f86de5 100644 --- a/helper_test.go +++ b/helper_test.go @@ -71,19 +71,3 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } - -func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) { - checkConn := mustConnect(t, config) - defer closeConn(t, checkConn) - - for i := 0; i < 10; i++ { - var found bool - err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err == pgx.ErrNoRows { - return - } else if err != nil { - t.Fatalf("Unable to check if conn is dead on server: %v", err) - } - } - t.Fatal("Expected conn to be disconnected from server, but it wasn't") -} diff --git a/query.go b/query.go index 61136092..b6470688 100644 --- a/query.go +++ b/query.go @@ -82,41 +82,6 @@ func (rows *Rows) close() { } } -// TODO - consider inlining in Close(). This method calling rows.close is a -// foot-gun waiting to happen if anyone puts anything between the call to this -// and rows.close. -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - case errorResponse: - err = rows.conn.rxErrorResponse(r) - if rows.err == nil { - rows.err = err - } - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } - } -} - // Close closes the rows, making the connection ready for use again. It is safe // to call Close after rows is already closed. func (rows *Rows) Close() { @@ -124,7 +89,6 @@ func (rows *Rows) Close() { return } rows.err = rows.conn.termContext(rows.err) - rows.readUntilReadyForQuery() rows.close() } @@ -174,10 +138,6 @@ func (rows *Rows) Next() bool { } switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false case dataRow: fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { @@ -188,7 +148,9 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - case bindComplete: + rows.close() + return false + default: err = rows.conn.processContextFreeMsg(t, r) if err != nil { diff --git a/query_test.go b/query_test.go index 40886f2e..24310ab3 100644 --- a/query_test.go +++ b/query_test.go @@ -1513,7 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", rows.Err()) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestQueryRowContextSuccess(t *testing.T) { @@ -1573,5 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } From 8cc480fc485a73281cdbcc41bc937a970133c0bf Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:44:27 -0600 Subject: [PATCH 17/23] Fix grammar --- conn.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conn.go b/conn.go index f7443719..3ee0fe6b 100644 --- a/conn.go +++ b/conn.go @@ -1075,7 +1075,7 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages is the same // regardless of when they occur. It also ignores messages that are only -// meaningful in a given context. These messages can occur do to a context +// meaningful in a given context. These messages can occur due to a context // deadline interrupting message processing. For example, an interrupted query // may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { From 9c74626d226753b61b8bdf0103749511975b6f70 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:44:39 -0600 Subject: [PATCH 18/23] Ping implemented in terms of PingContext --- conn.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 3ee0fe6b..51fab0e5 100644 --- a/conn.go +++ b/conn.go @@ -1385,8 +1385,7 @@ func (c *Conn) cancelQuery() error { } func (c *Conn) Ping() error { - _, err := c.Exec(";") - return err + return c.PingContext(context.Background()) } func (c *Conn) PingContext(ctx context.Context) error { From 6cdb58fc71181d84efb08496242dcab3ab4247fc Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 18:46:46 -0600 Subject: [PATCH 19/23] Exec implemented in terms of ExecContext --- conn.go | 107 +++++++++++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/conn.go b/conn.go index 51fab0e5..5ede5944 100644 --- a/conn.go +++ b/conn.go @@ -1020,56 +1020,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} // Exec executes sql. sql can be either a prepared statement name or an SQL string. // arguments should be referenced positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, err error) { - if err = c.lock(); err != nil { - return commandTag, err - } - - startTime := time.Now() - c.lastActivityTime = startTime - - defer func() { - if err == nil { - if c.shouldLog(LogLevelInfo) { - endTime := time.Now() - c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) - } - } else { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) - } - } - - if unlockErr := c.unlock(); unlockErr != nil && err == nil { - err = unlockErr - } - }() - - if err = c.sendQuery(sql, arguments...); err != nil { - return - } - - var softErr error - - for { - var t byte - var r *msgReader - t, r, err = c.rxMsg() - if err != nil { - return commandTag, err - } - - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return commandTag, softErr - case commandComplete: - commandTag = CommandTag(r.readCString()) - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } + return c.ExecContext(context.Background(), sql, arguments...) } // Processes messages that are not exclusive to one context such as @@ -1398,9 +1349,61 @@ func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interfa if err != nil { return "", err } + defer func() { + err = c.termContext(err) + }() + + if err = c.lock(); err != nil { + return commandTag, err + } + + startTime := time.Now() + c.lastActivityTime = startTime + + defer func() { + if err == nil { + if c.shouldLog(LogLevelInfo) { + endTime := time.Now() + c.log(LogLevelInfo, "Exec", "sql", sql, "args", logQueryArgs(arguments), "time", endTime.Sub(startTime), "commandTag", commandTag) + } + } else { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "Exec", "sql", sql, "args", logQueryArgs(arguments), "error", err) + } + } + + if unlockErr := c.unlock(); unlockErr != nil && err == nil { + err = unlockErr + } + }() + + if err = c.sendQuery(sql, arguments...); err != nil { + return + } + + var softErr error + + for { + var t byte + var r *msgReader + t, r, err = c.rxMsg() + if err != nil { + return commandTag, err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return commandTag, softErr + case commandComplete: + commandTag = CommandTag(r.readCString()) + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } + } + } - commandTag, err = c.Exec(sql, arguments...) - err = c.termContext(err) return commandTag, err } From deac6564eeb81e6ad3996b9e29f03854a8017f2d Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 19:16:13 -0600 Subject: [PATCH 20/23] Implement Query in terms of QueryContext - Merge Rows.close into Rows.Close - Merge Rows.abort into Rows.Fatal --- query.go | 91 ++++++++++++++++++++------------------------------ replication.go | 6 ++-- 2 files changed, 39 insertions(+), 58 deletions(-) diff --git a/query.go b/query.go index b6470688..aa664649 100644 --- a/query.go +++ b/query.go @@ -56,7 +56,9 @@ func (rows *Rows) FieldDescriptions() []FieldDescription { return rows.fields } -func (rows *Rows) close() { +// Close closes the rows, making the connection ready for use again. It is safe +// to call Close after rows is already closed. +func (rows *Rows) Close() { if rows.closed { return } @@ -68,6 +70,8 @@ func (rows *Rows) close() { rows.closed = true + rows.err = rows.conn.termContext(rows.err) + if rows.err == nil { if rows.conn.shouldLog(LogLevelInfo) { endTime := time.Now() @@ -82,31 +86,10 @@ func (rows *Rows) close() { } } -// Close closes the rows, making the connection ready for use again. It is safe -// to call Close after rows is already closed. -func (rows *Rows) Close() { - if rows.closed { - return - } - rows.err = rows.conn.termContext(rows.err) - rows.close() -} - func (rows *Rows) Err() error { return rows.err } -// abort signals that the query was not successfully sent to the server. -// This differs from Fatal in that it is not necessary to readUntilReadyForQuery -func (rows *Rows) abort(err error) { - if rows.err != nil { - return - } - - rows.err = err - rows.close() -} - // Fatal signals an error occurred after the query was sent to the server. It // closes the rows automatically. func (rows *Rows) Fatal(err error) { @@ -148,7 +131,7 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - rows.close() + rows.Close() return false default: @@ -408,32 +391,7 @@ func (rows *Rows) AfterClose(f func(*Rows)) { // be returned in an error state. So it is allowed to ignore the error returned // from Query and handle it in *Rows. func (c *Conn) Query(sql string, args ...interface{}) (*Rows, error) { - c.lastActivityTime = time.Now() - - rows := c.getRows(sql, args) - - if err := c.lock(); err != nil { - rows.abort(err) - return rows, err - } - rows.unlockConn = true - - ps, ok := c.preparedStatements[sql] - if !ok { - var err error - ps, err = c.Prepare("", sql) - if err != nil { - rows.abort(err) - return rows, rows.err - } - } - rows.sql = ps.SQL - rows.fields = ps.FieldDescriptions - err := c.sendPreparedQuery(ps, args...) - if err != nil { - rows.abort(err) - } - return rows, rows.err + return c.QueryContext(context.Background(), sql, args...) } func (c *Conn) getRows(sql string, args []interface{}) *Rows { @@ -460,19 +418,42 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { return (*Row)(rows) } -func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - err := c.initContext(ctx) +func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + c.lastActivityTime = time.Now() + + rows = c.getRows(sql, args) + + if err := c.lock(); err != nil { + rows.Fatal(err) + return rows, err + } + rows.unlockConn = true + + ps, ok := c.preparedStatements[sql] + if !ok { + var err error + ps, err = c.PrepareExContext(ctx, "", sql, nil) + if err != nil { + rows.Fatal(err) + return rows, rows.err + } + } + rows.sql = ps.SQL + rows.fields = ps.FieldDescriptions + + err = c.initContext(ctx) if err != nil { - return nil, err + rows.Fatal(err) + return rows, err } - rows, err := c.Query(sql, args...) + err = c.sendPreparedQuery(ps, args...) if err != nil { + rows.Fatal(err) err = c.termContext(err) - return nil, err } - return rows, nil + return rows, err } func (c *Conn) QueryRowContext(ctx context.Context, sql string, args ...interface{}) *Row { diff --git a/replication.go b/replication.go index 12a5c914..0acc9df9 100644 --- a/replication.go +++ b/replication.go @@ -312,14 +312,14 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { rows := rc.c.getRows(sql, nil) if err := rc.c.lock(); err != nil { - rows.abort(err) + rows.Fatal(err) return rows, err } rows.unlockConn = true err := rc.c.sendSimpleQuery(sql) if err != nil { - rows.abort(err) + rows.Fatal(err) } var t byte @@ -337,7 +337,7 @@ func (rc *ReplicationConn) sendReplicationModeQuery(sql string) (*Rows, error) { // only Oids. Not much we can do about this. default: if e := rc.c.processContextFreeMsg(t, r); e != nil { - rows.abort(e) + rows.Fatal(e) return rows, e } } From 048a75406f1139b19f1be31f3ec2f590c901fc8e Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 19:53:18 -0600 Subject: [PATCH 21/23] Fix context query cancellation Previous commits had a race condition due to not waiting for the PostgreSQL server to close the cancel query connection. This made it possible for the cancel request to impact a subsequent query on the same connection. This commit sets a flag that a cancel request was made and blocks until the PostgreSQL server closes the cancel connection. --- conn.go | 128 ++++++++++++++++++++++++++++++++++++++++--------- query.go | 5 ++ stress_test.go | 14 +++--- 3 files changed, 118 insertions(+), 29 deletions(-) diff --git a/conn.go b/conn.go index 5ede5944..f91929c5 100644 --- a/conn.go +++ b/conn.go @@ -93,7 +93,9 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error - readyForQuery bool // can the connection be used to send a query + readyForQuery bool // connection has received ReadyForQuery message since last query was sent + cancelQueryInProgress int32 + cancelQueryCompleted chan struct{} // context support ctxInProgress bool @@ -268,6 +270,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.channels = make(map[string]struct{}) atomic.StoreInt32(&c.status, connStatusIdle) c.lastActivityTime = time.Now() + c.cancelQueryCompleted = make(chan struct{}, 1) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) @@ -634,10 +637,15 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - return c.prepareEx(name, sql, opts) + return c.PrepareExContext(context.Background(), name, sql, opts) } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + err = c.initContext(ctx) if err != nil { return nil, err @@ -743,7 +751,25 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } // Deallocate released a prepared statement -func (c *Conn) Deallocate(name string) (err error) { +func (c *Conn) Deallocate(name string) error { + return c.deallocateContext(context.Background(), name) +} + +// TODO - consider making this public +func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return err + } + + err = c.initContext(ctx) + if err != nil { + return err + } + defer func() { + err = c.termContext(err) + }() + if err := c.ensureConnectionReadyForQuery(); err != nil { return err } @@ -818,6 +844,13 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + ctx, cancelFn := context.WithTimeout(context.Background(), timeout) + if err := c.waitForPreviousCancelQuery(ctx); err != nil { + cancelFn() + return nil, err + } + cancelFn() + if err := c.ensureConnectionReadyForQuery(); err != nil { return nil, err } @@ -1318,21 +1351,55 @@ func quoteIdentifier(s string) string { // ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 -func (c *Conn) cancelQuery() error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) - if err != nil { - return err +func (c *Conn) cancelQuery() { + if !atomic.CompareAndSwapInt32(&c.cancelQueryInProgress, 0, 1) { + panic("cancelQuery when cancelQueryInProgress") } - defer cancelConn.Close() - buf := make([]byte, 16) - binary.BigEndian.PutUint32(buf[0:4], 16) - binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) - _, err = cancelConn.Write(buf) - return err + if err := c.conn.SetDeadline(time.Now()); err != nil { + c.Close() // Close connection if unable to set deadline + return + } + + doCancel := func() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + // If server doesn't process cancellation request in bounded time then abort. + err = cancelConn.SetDeadline(time.Now().Add(15 * time.Second)) + if err != nil { + return err + } + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + if err != nil { + return err + } + + _, err = cancelConn.Read(buf) + if err != io.EOF { + return fmt.Errorf("Server failed to close connection after cancel query request: %v %v", err, buf) + } + + return nil + } + + go func() { + err := doCancel() + if err != nil { + c.Close() // Something is very wrong. Terminate the connection. + } + c.cancelQueryCompleted <- struct{}{} + }() } func (c *Conn) Ping() error { @@ -1345,6 +1412,11 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return "", err + } + err = c.initContext(ctx) if err != nil { return "", err @@ -1438,9 +1510,6 @@ func (c *Conn) termContext(opErr error) error { select { case err = <-c.closedChan: - if dlErr := c.conn.SetDeadline(time.Time{}); dlErr != nil { - c.Close() // Close connection if unable to disable deadline - } if opErr == nil { err = nil } @@ -1456,14 +1525,29 @@ func (c *Conn) contextHandler(ctx context.Context) { select { case <-ctx.Done(): c.cancelQuery() - if err := c.conn.SetDeadline(time.Now()); err != nil { - c.Close() // Close connection if unable to set deadline - } c.closedChan <- ctx.Err() case <-c.doneChan: } } +func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { + if atomic.LoadInt32(&c.cancelQueryInProgress) == 0 { + return nil + } + + select { + case <-c.cancelQueryCompleted: + atomic.StoreInt32(&c.cancelQueryInProgress, 0) + if err := c.conn.SetDeadline(time.Time{}); err != nil { + c.Close() // Close connection if unable to disable deadline + return err + } + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + func (c *Conn) ensureConnectionReadyForQuery() error { for !c.readyForQuery { t, r, err := c.rxMsg() diff --git a/query.go b/query.go index aa664649..dd7aafb0 100644 --- a/query.go +++ b/query.go @@ -419,6 +419,11 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (rows *Rows, err error) { + err = c.waitForPreviousCancelQuery(ctx) + if err != nil { + return nil, err + } + c.lastActivityTime = time.Now() rows = c.getRows(sql, args) diff --git a/stress_test.go b/stress_test.go index d22d9d6b..72d48a5c 100644 --- a/stress_test.go +++ b/stress_test.go @@ -66,7 +66,7 @@ func TestStressConnPool(t *testing.T) { action := actions[rand.Intn(len(actions))] err := action.fn(pool, n) if err != nil { - errChan <- err + errChan <- fmt.Errorf("%s: %v", action.name, err) break } } @@ -355,19 +355,19 @@ func canceledQueryContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - rows, err := pool.QueryContext(ctx, "select pg_sleep(5)") + rows, err := pool.QueryContext(ctx, "select pg_sleep(2)") if err == context.Canceled { return nil } else if err != nil { - return fmt.Errorf("canceledQueryContext: Only allowed error is context.Canceled, got %v", err) + return fmt.Errorf("Only allowed error is context.Canceled, got %v", err) } for rows.Next() { - return errors.New("canceledQueryContext: should never receive row") + return errors.New("should never receive row") } if rows.Err() != context.Canceled { - return fmt.Errorf("canceledQueryContext: Expected context.Canceled error, got %v", rows.Err()) + return fmt.Errorf("Expected context.Canceled error, got %v", rows.Err()) } return nil @@ -380,9 +380,9 @@ func canceledExecContext(pool *pgx.ConnPool, actionNum int) error { cancelFunc() }() - _, err := pool.ExecContext(ctx, "select pg_sleep(5)") + _, err := pool.ExecContext(ctx, "select pg_sleep(2)") if err != context.Canceled { - return fmt.Errorf("canceledExecContext: Expected context.Canceled error, got %v", err) + return fmt.Errorf("Expected context.Canceled error, got %v", err) } return nil From d0a6921d124dfab48c89004e1a683bce180b795f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 11 Feb 2017 20:40:28 -0600 Subject: [PATCH 22/23] Add dependency to travis.yml --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index 4a3b91e2..9ae8d963 100644 --- a/.travis.yml +++ b/.travis.yml @@ -52,6 +52,7 @@ install: - go get -u gopkg.in/inconshreveable/log15.v2 - go get -u github.com/jackc/fake - go get -u golang.org/x/net/context + - go get -u github.com/jackc/pgmock/pgmsg script: - go test -v -race -short ./... From cc414269c1bbca67c779c9798c13bb78c0a1843f Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sun, 12 Feb 2017 08:12:36 -0600 Subject: [PATCH 23/23] Remove debugging Println --- stdlib/sql.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/stdlib/sql.go b/stdlib/sql.go index 74218a7b..41c9d4dd 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -255,8 +255,6 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr return nil, err } - fmt.Println("ere") - return &Rows{rows: rows}, nil }