mirror of https://github.com/jackc/pgx.git
Add Ping, PingContext, and ExecContext
parent
a52a6bd555
commit
78adfb13d7
96
conn.go
96
conn.go
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"io"
|
||||
"net"
|
||||
"net/url"
|
||||
|
@ -39,6 +40,22 @@ type ConnConfig struct {
|
|||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return network, address
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
// Use ConnPool to manage access to multiple database connections from multiple
|
||||
// goroutines.
|
||||
|
@ -194,17 +211,7 @@ func connect(config ConnConfig, pgTypes map[Oid]PgType, pgsqlAfInet *byte, pgsql
|
|||
}
|
||||
}
|
||||
|
||||
network := "tcp"
|
||||
address := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(c.config.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = c.config.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(c.config.Port), 10)
|
||||
}
|
||||
}
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
c.config.Dial = (&net.Dialer{KeepAlive: 5 * time.Minute}).Dial
|
||||
}
|
||||
|
@ -1292,3 +1299,70 @@ func (c *Conn) SetLogLevel(lvl int) (int, error) {
|
|||
func quoteIdentifier(s string) string {
|
||||
return `"` + strings.Replace(s, `"`, `""`, -1) + `"`
|
||||
}
|
||||
|
||||
// cancelQuery sends a cancel request to the PostgreSQL server. It returns an
|
||||
// error if unable to deliver the cancel request, but lack of an error does not
|
||||
// ensure that the query was canceled. As specified in the documentation, there
|
||||
// is no way to be sure a query was canceled. See
|
||||
// https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861
|
||||
func (c *Conn) cancelQuery() error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancelConn.Close()
|
||||
|
||||
buf := make([]byte, 16)
|
||||
binary.BigEndian.PutUint32(buf[0:4], 16)
|
||||
binary.BigEndian.PutUint32(buf[4:8], 80877102)
|
||||
binary.BigEndian.PutUint32(buf[8:12], uint32(c.Pid))
|
||||
binary.BigEndian.PutUint32(buf[12:16], uint32(c.SecretKey))
|
||||
_, err = cancelConn.Write(buf)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) Ping() error {
|
||||
_, err := c.Exec(";")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) PingContext(ctx context.Context) error {
|
||||
_, err := c.ExecContext(ctx, ";")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return "", ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
doneChan := make(chan struct{})
|
||||
closedChan := make(chan bool)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.cancelQuery()
|
||||
c.Close()
|
||||
<-doneChan
|
||||
closedChan <- true
|
||||
case <-doneChan:
|
||||
closedChan <- false
|
||||
}
|
||||
}()
|
||||
|
||||
commandTag, err = c.Exec(sql, arguments...)
|
||||
|
||||
// Signal cancelation goroutine that operation is done
|
||||
doneChan <- struct{}{}
|
||||
|
||||
// If c was closed due to context cancelation then return context err
|
||||
if <-closedChan {
|
||||
return "", ctx.Err()
|
||||
}
|
||||
|
||||
return commandTag, err
|
||||
}
|
||||
|
|
68
conn_test.go
68
conn_test.go
|
@ -3,6 +3,7 @@ package pgx_test
|
|||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"golang.org/x/net/context"
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
|
@ -816,6 +817,73 @@ func TestExecFailure(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExecContextWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
commandTag, err := conn.ExecContext(ctx, "create temporary table foo(id integer primary key);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if commandTag != "CREATE TABLE" {
|
||||
t.Fatalf("Unexpected results from ExecContext: %v", commandTag)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextFailureWithoutCancelation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
|
||||
if _, err := conn.ExecContext(ctx, "selct;"); err == nil {
|
||||
t.Fatal("Expected SQL syntax error")
|
||||
}
|
||||
|
||||
rows, _ := conn.Query("select 1")
|
||||
rows.Close()
|
||||
if rows.Err() != nil {
|
||||
t.Fatalf("ExecContext failure appears to have broken connection: %v", rows.Err())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecContextCancelationCancelsQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
cancelFunc()
|
||||
}()
|
||||
|
||||
_, err := conn.ExecContext(ctx, "select pg_sleep(60)")
|
||||
if err != context.Canceled {
|
||||
t.Fatal("Expected context.Canceled err, got %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
checkConn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, checkConn)
|
||||
|
||||
var found bool
|
||||
err = checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found)
|
||||
if err != pgx.ErrNoRows {
|
||||
t.Fatal("Expected context canceled connection to be disconnected from server, but it wasn't")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepare(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
|
Loading…
Reference in New Issue