mirror of https://github.com/jackc/pgx.git
Extracted more context handling
parent
004c18e5a2
commit
72b6d32e2f
71
conn.go
71
conn.go
|
@ -90,8 +90,9 @@ type Conn struct {
|
||||||
causeOfDeath error
|
causeOfDeath error
|
||||||
|
|
||||||
// context support
|
// context support
|
||||||
doneChan chan struct{}
|
ctxInProgress bool
|
||||||
closedChan chan struct{}
|
doneChan chan struct{}
|
||||||
|
closedChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// PreparedStatement is a description of a prepared statement
|
// PreparedStatement is a description of a prepared statement
|
||||||
|
@ -262,7 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||||
c.alive = true
|
c.alive = true
|
||||||
c.lastActivityTime = time.Now()
|
c.lastActivityTime = time.Now()
|
||||||
c.doneChan = make(chan struct{})
|
c.doneChan = make(chan struct{})
|
||||||
c.closedChan = make(chan struct{})
|
c.closedChan = make(chan error)
|
||||||
|
|
||||||
if tlsConfig != nil {
|
if tlsConfig != nil {
|
||||||
if c.shouldLog(LogLevelDebug) {
|
if c.shouldLog(LogLevelDebug) {
|
||||||
|
@ -629,22 +630,14 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
select {
|
err = c.initContext(ctx)
|
||||||
case <-ctx.Done():
|
if err != nil {
|
||||||
return nil, ctx.Err()
|
return nil, err
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.contextHandler(ctx)
|
|
||||||
|
|
||||||
ps, err = c.prepareEx(name, sql, opts)
|
ps, err = c.prepareEx(name, sql, opts)
|
||||||
|
err = c.termContext(err)
|
||||||
select {
|
return ps, err
|
||||||
case <-c.closedChan:
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case c.doneChan <- struct{}{}:
|
|
||||||
return ps, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
|
||||||
|
@ -1371,22 +1364,56 @@ func (c *Conn) PingContext(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
|
||||||
|
err = c.initContext(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
commandTag, err = c.Exec(sql, arguments...)
|
||||||
|
err = c.termContext(err)
|
||||||
|
return commandTag, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) initContext(ctx context.Context) error {
|
||||||
|
if c.ctxInProgress {
|
||||||
|
return errors.New("ctx already in progress")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Done() == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return "", ctx.Err()
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ctxInProgress = true
|
||||||
|
|
||||||
go c.contextHandler(ctx)
|
go c.contextHandler(ctx)
|
||||||
|
|
||||||
commandTag, err = c.Exec(sql, arguments...)
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) termContext(opErr error) error {
|
||||||
|
if !c.ctxInProgress {
|
||||||
|
return opErr
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-c.closedChan:
|
case err = <-c.closedChan:
|
||||||
return "", ctx.Err()
|
if opErr == nil {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
case c.doneChan <- struct{}{}:
|
case c.doneChan <- struct{}{}:
|
||||||
return commandTag, err
|
err = opErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.ctxInProgress = false
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) contextHandler(ctx context.Context) {
|
func (c *Conn) contextHandler(ctx context.Context) {
|
||||||
|
@ -1394,7 +1421,7 @@ func (c *Conn) contextHandler(ctx context.Context) {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
c.cancelQuery()
|
c.cancelQuery()
|
||||||
c.Close()
|
c.Close()
|
||||||
c.closedChan <- struct{}{}
|
c.closedChan <- ctx.Err()
|
||||||
case <-c.doneChan:
|
case <-c.doneChan:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -182,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
|
||||||
|
|
||||||
// Release gives up use of a connection.
|
// Release gives up use of a connection.
|
||||||
func (p *ConnPool) Release(conn *Conn) {
|
func (p *ConnPool) Release(conn *Conn) {
|
||||||
|
if conn.ctxInProgress {
|
||||||
|
panic("should never release when context is in progress")
|
||||||
|
}
|
||||||
|
|
||||||
if conn.TxStatus != 'I' {
|
if conn.TxStatus != 'I' {
|
||||||
conn.Exec("rollback")
|
conn.Exec("rollback")
|
||||||
}
|
}
|
||||||
|
|
33
query.go
33
query.go
|
@ -50,8 +50,6 @@ type Rows struct {
|
||||||
afterClose func(*Rows)
|
afterClose func(*Rows)
|
||||||
unlockConn bool
|
unlockConn bool
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
ctx context.Context
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
func (rows *Rows) FieldDescriptions() []FieldDescription {
|
||||||
|
@ -84,6 +82,9 @@ func (rows *Rows) close() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO - consider inlining in Close(). This method calling rows.close is a
|
||||||
|
// foot-gun waiting to happen if anyone puts anything between the call to this
|
||||||
|
// and rows.close.
|
||||||
func (rows *Rows) readUntilReadyForQuery() {
|
func (rows *Rows) readUntilReadyForQuery() {
|
||||||
for {
|
for {
|
||||||
t, r, err := rows.conn.rxMsg()
|
t, r, err := rows.conn.rxMsg()
|
||||||
|
@ -122,16 +123,8 @@ func (rows *Rows) Close() {
|
||||||
if rows.closed {
|
if rows.closed {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
rows.err = rows.conn.termContext(rows.err)
|
||||||
rows.readUntilReadyForQuery()
|
rows.readUntilReadyForQuery()
|
||||||
|
|
||||||
if rows.ctx != nil {
|
|
||||||
select {
|
|
||||||
case <-rows.conn.closedChan:
|
|
||||||
rows.err = rows.ctx.Err()
|
|
||||||
case rows.conn.doneChan <- struct{}{}:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rows.close()
|
rows.close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -506,20 +499,16 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
|
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
|
||||||
go c.contextHandler(ctx)
|
err := c.initContext(ctx)
|
||||||
|
|
||||||
rows, err := c.Query(sql, args...)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
select {
|
return nil, err
|
||||||
case <-c.closedChan:
|
|
||||||
return rows, ctx.Err()
|
|
||||||
case c.doneChan <- struct{}{}:
|
|
||||||
return rows, err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rows.ctx = ctx
|
rows, err := c.Query(sql, args...)
|
||||||
|
if err != nil {
|
||||||
|
err = c.termContext(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return rows, nil
|
return rows, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue