From 78adfb13d796427deafa89fc45aa5c7e47f8d51b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 4 Feb 2017 14:20:00 -0600 Subject: [PATCH] Add Ping, PingContext, and ExecContext --- conn.go | 96 ++++++++++++++++++++++++++++++++++++++++++++++------ conn_test.go | 68 +++++++++++++++++++++++++++++++++++++ 2 files changed, 153 insertions(+), 11 deletions(-) diff --git a/conn.go b/conn.go index 602ecbff..645b9c5d 100644 --- a/conn.go +++ b/conn.go @@ -8,6 +8,7 @@ import ( "encoding/hex" "errors" "fmt" + "golang.org/x/net/context" "io" "net" "net/url" @@ -39,6 +40,22 @@ type ConnConfig struct { RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name) } +func (cc *ConnConfig) networkAddress() (network, address string) { + network = "tcp" + address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) + // See if host is a valid path, if yes connect with a socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network = "unix" + address = cc.Host + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) + } + } + + return network, address +} + // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Use ConnPool to manage access to multiple database connections from multiple // goroutines. @@ -194,17 +211,7 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql } } - network := "tcp" - address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(c.config.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = c.config.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10) - } - } + network, address := c.config.networkAddress() if c.config.Dial == nil { c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial } @@ -1292,3 +1299,70 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) { func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } + +// cancelQuery sends a cancel request to the PostgreSQL server. It returns an +// error if unable to deliver the cancel request, but lack of an error does not +// ensure that the query was canceled. As specified in the documentation, there +// is no way to be sure a query was canceled. See +// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 +func (c *Conn) cancelQuery() error { + network, address := c.config.networkAddress() + cancelConn, err := c.config.Dial(network, address) + if err != nil { + return err + } + defer cancelConn.Close() + + buf := make([]byte, 16) + binary.BigEndian.PutUint32(buf[0:4], 16) + binary.BigEndian.PutUint32(buf[4:8], 80877102) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey)) + _, err = cancelConn.Write(buf) + return err +} + +func (c *Conn) Ping() error { + _, err := c.Exec(";") + return err +} + +func (c *Conn) PingContext(ctx context.Context) error { + _, err := c.ExecContext(ctx, ";") + return err +} + +func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + } + + doneChan := make(chan struct{}) + closedChan := make(chan bool) + + go func() { + select { + case <-ctx.Done(): + c.cancelQuery() + c.Close() + <-doneChan + closedChan <- true + case <-doneChan: + closedChan <- false + } + }() + + commandTag, err = c.Exec(sql, arguments...) + + // Signal cancelation goroutine that operation is done + doneChan <- struct{}{} + + // If c was closed due to context cancelation then return context err + if <-closedChan { + return "", ctx.Err() + } + + return commandTag, err +} diff --git a/conn_test.go b/conn_test.go index 9ed073ce..a9cf02c9 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,6 +3,7 @@ package pgx_test import ( "crypto/tls" "fmt" + "golang.org/x/net/context" "net" "os" "reflect" @@ -816,6 +817,73 @@ func TestExecFailure(t *testing.T) { } } +func TestExecContextWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if commandTag != "CREATE TABLE" { + t.Fatalf("Unexpected results from ExecContext: %v", commandTag) + } +} + +func TestExecContextFailureWithoutCancelation(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + + if _, err := conn.ExecContext(ctx, "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } + + rows, _ := conn.Query("select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err()) + } +} + +func TestExecContextCancelationCancelsQuery(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ctx, cancelFunc := context.WithCancel(context.Background()) + go func() { + time.Sleep(500 * time.Millisecond) + cancelFunc() + }() + + _, err := conn.ExecContext(ctx, "select pg_sleep(60)") + if err != context.Canceled { + t.Fatal("Expected context.Canceled err, got %v", err) + } + + time.Sleep(500 * time.Millisecond) + + checkConn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, checkConn) + + var found bool + err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) + if err != pgx.ErrNoRows { + t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't") + } +} + func TestPrepare(t *testing.T) { t.Parallel()