Enable passing nil context

query-exec-mode
bakape 2020-01-01 13:09:50 +02:00
parent 3e503b7b1a
commit 89416dd805
3 changed files with 116 additions and 77 deletions

3
.gitignore vendored
View File

@ -1,2 +1,3 @@
.envrc
vendor/
vendor/
.vscode

3
doc.go
View File

@ -23,6 +23,9 @@ Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection.
A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional
slight performance increase, if you don't need the operation to be cancellable.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort.
*/

187
pgconn.go
View File

@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig")
}
if ctx == nil {
ctx = context.Background()
}
// Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{
{
@ -362,13 +366,15 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
default:
if ctx != nil {
select {
case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
n, err := pgConn.conn.Write(buf)
if err != nil {
@ -392,13 +398,15 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
msg, err := pgConn.receiveMessage()
if err != nil {
@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
if ctx != nil {
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
// Ignore any errors sending Terminate message and waiting for server to close connection.
// This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
@ -586,13 +596,15 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String())
_ctx := ctx
if _ctx == nil {
_ctx = context.Background()
}
cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String())
if err != nil {
return err
}
defer cancelConn.Close()
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
if ctx != nil {
contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) },
)
contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch()
}
buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16)
@ -712,14 +730,16 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if ctx != nil {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
for {
msg, err := pgConn.receiveMessage()
@ -752,16 +772,19 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
if ctx != nil {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
}
pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
pgConn.contextWatcher.Watch(ctx)
buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf)
@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result)
pgConn.execExtendedSuffix(buf, result)
return result
}
@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result)
pgConn.execExtendedSuffix(buf, result)
return result
}
@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
ctx: ctx,
}
result := &pgConn.resultReader
if ctx == nil {
pgConn.resultReader.ctx = context.Background()
}
if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err)
@ -859,20 +885,22 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result
}
select {
case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true
pgConn.unlock()
return result
default:
if ctx != nil {
select {
case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
result.closed = true
pgConn.unlock()
return result
default:
}
pgConn.contextWatcher.Watch(ctx)
}
pgConn.contextWatcher.Watch(ctx)
return result
}
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) {
func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf)
@ -893,14 +921,16 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, err
}
select {
case <-ctx.Done():
pgConn.unlock()
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
if ctx != nil {
select {
case <-ctx.Done():
pgConn.unlock()
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
// Send copy to command
buf := pgConn.wbuf
@ -952,13 +982,15 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
}
defer pgConn.unlock()
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
if ctx != nil {
select {
case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()}
default:
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
}
pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch()
// Send copy to command
buf := pgConn.wbuf
@ -1344,16 +1376,19 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
ctx: ctx,
}
multiResult := &pgConn.multiResultReader
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
if ctx != nil {
select {
case <-ctx.Done():
multiResult.closed = true
multiResult.err = &contextAlreadyDoneError{err: ctx.Err()}
pgConn.unlock()
return multiResult
default:
}
pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
pgConn.contextWatcher.Watch(ctx)
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)