Enable passing nil context

This commit is contained in:
bakape 2020-01-01 13:09:50 +02:00
parent 3e503b7b1a
commit 89416dd805
3 changed files with 116 additions and 77 deletions

1
.gitignore vendored
View File

@ -1,2 +1,3 @@
.envrc .envrc
vendor/ vendor/
.vscode

3
doc.go
View File

@ -23,6 +23,9 @@ Context Support
All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the All potentially blocking operations take a context.Context. If a context is canceled while the method is in progress the
method immediately returns. In most circumstances, this will close the underlying connection. method immediately returns. In most circumstances, this will close the underlying connection.
A nil context can be passed for convenience. This has the same effect as passing context.Background() with an additional
slight performance increase, if you don't need the operation to be cancellable.
The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the The CancelRequest method may be used to request the PostgreSQL server cancel an in-progress query without forcing the
client to abort. client to abort.
*/ */

View File

@ -116,6 +116,10 @@ func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err err
panic("config must be created by ParseConfig") panic("config must be created by ParseConfig")
} }
if ctx == nil {
ctx = context.Background()
}
// Simplify usage by treating primary config and fallbacks the same. // Simplify usage by treating primary config and fallbacks the same.
fallbackConfigs := []*FallbackConfig{ fallbackConfigs := []*FallbackConfig{
{ {
@ -362,6 +366,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return &contextAlreadyDoneError{err: ctx.Err()} return &contextAlreadyDoneError{err: ctx.Err()}
@ -369,6 +374,7 @@ func (pgConn *PgConn) SendBytes(ctx context.Context, buf []byte) error {
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
n, err := pgConn.conn.Write(buf) n, err := pgConn.conn.Write(buf)
if err != nil { if err != nil {
@ -392,6 +398,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -399,6 +406,7 @@ func (pgConn *PgConn) ReceiveMessage(ctx context.Context) (pgproto3.BackendMessa
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
if err != nil { if err != nil {
@ -489,8 +497,10 @@ func (pgConn *PgConn) Close(ctx context.Context) error {
defer pgConn.conn.Close() defer pgConn.conn.Close()
if ctx != nil {
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Ignore any errors sending Terminate message and waiting for server to close connection. // Ignore any errors sending Terminate message and waiting for server to close connection.
// This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully // This mimics the behavior of libpq PQfinish. It calls closePGconn which calls sendTerminateConn which purposefully
@ -586,6 +596,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -593,6 +604,7 @@ func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs [
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
@ -673,18 +685,24 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error {
// the connection config. This is important in high availability configurations where fallback connections may be // the connection config. This is important in high availability configurations where fallback connections may be
// specified or DNS may be used to load balance. // specified or DNS may be used to load balance.
serverAddr := pgConn.conn.RemoteAddr() serverAddr := pgConn.conn.RemoteAddr()
cancelConn, err := pgConn.config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) _ctx := ctx
if _ctx == nil {
_ctx = context.Background()
}
cancelConn, err := pgConn.config.DialFunc(_ctx, serverAddr.Network(), serverAddr.String())
if err != nil { if err != nil {
return err return err
} }
defer cancelConn.Close() defer cancelConn.Close()
if ctx != nil {
contextWatcher := ctxwatch.NewContextWatcher( contextWatcher := ctxwatch.NewContextWatcher(
func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
func() { cancelConn.SetDeadline(time.Time{}) }, func() { cancelConn.SetDeadline(time.Time{}) },
) )
contextWatcher.Watch(ctx) contextWatcher.Watch(ctx)
defer contextWatcher.Unwatch() defer contextWatcher.Unwatch()
}
buf := make([]byte, 16) buf := make([]byte, 16)
binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[0:4], 16)
@ -712,6 +730,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -720,6 +739,7 @@ func (pgConn *PgConn) WaitForNotification(ctx context.Context) error {
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
for { for {
msg, err := pgConn.receiveMessage() msg, err := pgConn.receiveMessage()
@ -752,7 +772,7 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
ctx: ctx, ctx: ctx,
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
@ -762,6 +782,9 @@ func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader {
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Query{String: sql}).Encode(buf) buf = (&pgproto3.Query{String: sql}).Encode(buf)
@ -808,7 +831,7 @@ func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues []
buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf)
buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result) pgConn.execExtendedSuffix(buf, result)
return result return result
} }
@ -834,7 +857,7 @@ func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramVa
buf := pgConn.wbuf buf := pgConn.wbuf
buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf)
pgConn.execExtendedSuffix(ctx, buf, result) pgConn.execExtendedSuffix(buf, result)
return result return result
} }
@ -845,6 +868,9 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
ctx: ctx, ctx: ctx,
} }
result := &pgConn.resultReader result := &pgConn.resultReader
if ctx == nil {
pgConn.resultReader.ctx = context.Background()
}
if err := pgConn.lock(); err != nil { if err := pgConn.lock(); err != nil {
result.concludeCommand(nil, err) result.concludeCommand(nil, err)
@ -859,6 +885,7 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
return result return result
} }
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()}) result.concludeCommand(nil, &contextAlreadyDoneError{err: ctx.Err()})
@ -868,11 +895,12 @@ func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]by
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
}
return result return result
} }
func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { func (pgConn *PgConn) execExtendedSuffix(buf []byte, result *ResultReader) {
buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf)
buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf)
buf = (&pgproto3.Sync{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf)
@ -893,6 +921,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
return nil, err return nil, err
} }
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
pgConn.unlock() pgConn.unlock()
@ -901,6 +930,7 @@ func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (Comm
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@ -952,6 +982,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
defer pgConn.unlock() defer pgConn.unlock()
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, &contextAlreadyDoneError{err: ctx.Err()} return nil, &contextAlreadyDoneError{err: ctx.Err()}
@ -959,6 +990,7 @@ func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (Co
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
defer pgConn.contextWatcher.Unwatch() defer pgConn.contextWatcher.Unwatch()
}
// Send copy to command // Send copy to command
buf := pgConn.wbuf buf := pgConn.wbuf
@ -1344,7 +1376,7 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
ctx: ctx, ctx: ctx,
} }
multiResult := &pgConn.multiResultReader multiResult := &pgConn.multiResultReader
if ctx != nil {
select { select {
case <-ctx.Done(): case <-ctx.Done():
multiResult.closed = true multiResult.closed = true
@ -1354,6 +1386,9 @@ func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultR
default: default:
} }
pgConn.contextWatcher.Watch(ctx) pgConn.contextWatcher.Watch(ctx)
} else {
pgConn.multiResultReader.ctx = context.Background()
}
batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf)