Add cancel request to PgConn

RecoverFromTimeout automatically tries to cancel in progress requests.
pull/483/head
Jack Christensen 2018-12-31 17:32:04 -06:00
parent 084423ae69
commit a8ac061b6a
2 changed files with 63 additions and 0 deletions

View File

@ -562,8 +562,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool {
}
pgConn.resetBatch()
// Clear any existing timeout
pgConn.NetConn.SetDeadline(time.Time{})
// Try to cancel any in-progress requests
for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ {
pgConn.CancelRequest(ctx)
}
cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn)
defer cleanupContext()
@ -669,3 +675,38 @@ func errorResponseToPgError(msg *pgproto3.ErrorResponse) PgError {
Routine: msg.Routine,
}
}
// CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel
// request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there
// is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9
func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.NetConn.RemoteAddr()
cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
if err != nil {
return err
}
defer cancelConn.Close()
cleanupContext := contextDoneToConnDeadline(ctx, cancelConn)
defer cleanupContext()
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
binary.BigEndian.PutUint32(buf[4:8], 80877102)
binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.PID))
binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.SecretKey))
_, err = cancelConn.Write(buf)
if err != nil {
return preferContextOverNetTimeoutError(ctx, err)
}
_, err = cancelConn.Read(buf)
if err != io.EOF {
return fmt.Errorf("Server failed to close connection after cancel query request: %v", preferContextOverNetTimeoutError(ctx, err))
}
return nil
}

View File

@ -264,6 +264,8 @@ func TestConnExecContextCanceled(t *testing.T) {
result, err := pgConn.Exec(ctx, "select current_database(), pg_sleep(1)")
require.Nil(t, result)
assert.Equal(t, context.DeadlineExceeded, err)
assert.True(t, pgConn.RecoverFromTimeout(context.Background()))
}
func TestConnRecoverFromTimeout(t *testing.T) {
@ -287,3 +289,23 @@ func TestConnRecoverFromTimeout(t *testing.T) {
}
cancel()
}
func TestConnCancelQuery(t *testing.T) {
pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
require.Nil(t, err)
defer closeConn(t, pgConn)
pgConn.SendExec("select current_database(), pg_sleep(5)")
err = pgConn.Flush(context.Background())
require.Nil(t, err)
err = pgConn.CancelRequest(context.Background())
require.Nil(t, err)
_, err = pgConn.GetResult(context.Background()).Close()
if err, ok := err.(pgconn.PgError); ok {
assert.Equal(t, "57014", err.Code)
} else {
t.Errorf("expected pgconn.PgError got %v", err)
}
}