Add Ping, PingContext, and ExecContext

context
Jack Christensen 2017-02-04 14:20:00 -06:00
parent a52a6bd555
commit 78adfb13d7
2 changed files with 153 additions and 11 deletions

96
conn.go
View File

@ -8,6 +8,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"golang.org/x/net/context"
"io"
"net"
"net/url"
@ -39,6 +40,22 @@ type ConnConfig struct {
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
}
func (cc *ConnConfig) networkAddress() (network, address string) {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
}
return network, address
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
@ -194,17 +211,7 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql
}
}
network := "tcp"
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(c.config.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = c.config.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
}
}
network, address := c.config.networkAddress()
if c.config.Dial == nil {
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
}
@ -1292,3 +1299,70 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) {
func quoteIdentifier(s string) string {
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
}
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
// error if unable to deliver the cancel request, but lack of an error does not
// ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
func (c *Conn) cancelQuery() error {
network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
if err != nil {
return err
}
defer cancelConn.Close()
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
_, err = cancelConn.Write(buf)
return err
}
func (c *Conn) Ping() error {
_, err := c.Exec(";")
return err
}
func (c *Conn) PingContext(ctx context.Context) error {
_, err := c.ExecContext(ctx, ";")
return err
}
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
select {
case <-ctx.Done():
return "", ctx.Err()
default:
}
doneChan := make(chan struct{})
closedChan := make(chan bool)
go func() {
select {
case <-ctx.Done():
c.cancelQuery()
c.Close()
<-doneChan
closedChan <- true
case <-doneChan:
closedChan <- false
}
}()
commandTag, err = c.Exec(sql, arguments...)
// Signal cancelation goroutine that operation is done
doneChan <- struct{}{}
// If c was closed due to context cancelation then return context err
if <-closedChan {
return "", ctx.Err()
}
return commandTag, err
}

View File

@ -3,6 +3,7 @@ package pgx_test
import (
"crypto/tls"
"fmt"
"golang.org/x/net/context"
"net"
"os"
"reflect"
@ -816,6 +817,73 @@ func TestExecFailure(t *testing.T) {
}
}
func TestExecContextWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
if err != nil {
t.Fatal(err)
}
if commandTag != "CREATE TABLE" {
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
}
}
func TestExecContextFailureWithoutCancelation(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
t.Fatal("Expected SQL syntax error")
}
rows, _ := conn.Query("select 1")
rows.Close()
if rows.Err() != nil {
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
}
}
func TestExecContextCancelationCancelsQuery(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
ctx, cancelFunc := context.WithCancel(context.Background())
go func() {
time.Sleep(500 * time.Millisecond)
cancelFunc()
}()
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
if err != context.Canceled {
t.Fatal("Expected context.Canceled err, got %v", err)
}
time.Sleep(500 * time.Millisecond)
checkConn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, checkConn)
var found bool
err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found)
if err != pgx.ErrNoRows {
t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't")
}
}
func TestPrepare(t *testing.T) {
t.Parallel()