Enable passing nil context

query-exec-mode
bakape 2020-01-01 13:09:50 +02:00
parent 3e503b7b1a
commit 89416dd805
3 changed files with 116 additions and 77 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.envrc
vendor/
.vscode

3
doc.go
View File

@ -23,6 +23,9 @@ Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional
slight performance increase, if you don't need the operation to be cancellable.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/

View File

@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig")
}
if ctx == nil {
ctx = context.Background()
}
// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
@ -362,6 +366,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
}
defer pgConn.unlock()
if ctx != nil {
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
@ -369,6 +374,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
n, err := pgConn.conn.Write(buf)
if err != nil {
@ -392,6 +398,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
}
defer pgConn.unlock()
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -399,6 +406,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
msg, err := pgConn.receiveMessage()
if err != nil {
@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close()
if ctx != nil {
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
// Ignore any errors sending Terminate message and waiting for server to close connection.
// This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
@ -586,6 +596,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
}
defer pgConn.unlock()
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -593,6 +604,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
_ctx := ctx
if _ctx == nil {
_ctx = context.Background()
}
cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String())
if err != nil {
return err
}
defer cancelConn.Close()
if ctx != nil {
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
}
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
@ -712,6 +730,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
}
defer pgConn.unlock()
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
@ -720,6 +739,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
for {
msg, err := pgConn.receiveMessage()
@ -752,7 +772,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
if ctx != nil {
select {
case <-ctx.Done():
multiResult.closed = true
@ -762,6 +782,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
default:
}
pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result)
pgConn.execExtendedSuffix(buf, result)
return result
}
@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result)
pgConn.execExtendedSuffix(buf, result)
return result
}
@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
ctx: ctx,
}
result := &pgConn.resultReader
if ctx == nil {
pgConn.resultReader.ctx = context.Background()
}
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err)
@ -859,6 +885,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
}
if ctx != nil {
select {
case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
@ -868,11 +895,12 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
default:
}
pgConn.contextWatcher.Watch(ctx)
}
return result
}
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) {
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
@ -893,6 +921,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, err
}
if ctx != nil {
select {
case <-ctx.Done():
pgConn.unlock()
@ -901,6 +930,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command
buf := pgConn.wbuf
@ -952,6 +982,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
defer pgConn.unlock()
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -959,6 +990,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command
buf := pgConn.wbuf
@ -1344,7 +1376,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
if ctx != nil {
select {
case <-ctx.Done():
multiResult.closed = true
@ -1354,6 +1386,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
default:
}
pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)