Extracted more context handling

context
Jack Christensen 2017-02-07 21:49:58 -06:00
parent 004c18e5a2
commit 72b6d32e2f
3 changed files with 64 additions and 44 deletions

71
conn.go
View File

@ -90,8 +90,9 @@ type Conn struct {
causeOfDeath error causeOfDeath error
// context support // context support
doneChan chan struct{} ctxInProgress bool
closedChan chan struct{} doneChan chan struct{}
closedChan chan error
} }
// PreparedStatement is a description of a prepared statement // PreparedStatement is a description of a prepared statement
@ -262,7 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.alive = true c.alive = true
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
c.doneChan = make(chan struct{}) c.doneChan = make(chan struct{})
c.closedChan = make(chan struct{}) c.closedChan = make(chan error)
if tlsConfig != nil { if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) { if c.shouldLog(LogLevelDebug) {
@ -629,22 +630,14 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
} }
func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
select { err = c.initContext(ctx)
case <-ctx.Done(): if err != nil {
return nil, ctx.Err() return nil, err
default:
} }
go c.contextHandler(ctx)
ps, err = c.prepareEx(name, sql, opts) ps, err = c.prepareEx(name, sql, opts)
err = c.termContext(err)
select { return ps, err
case <-c.closedChan:
return nil, ctx.Err()
case c.doneChan <- struct{}{}:
return ps, err
}
} }
func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
@ -1371,22 +1364,56 @@ func (c *Conn) PingContext(ctx context.Context) error {
} }
func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) {
err = c.initContext(ctx)
if err != nil {
return "", err
}
commandTag, err = c.Exec(sql, arguments...)
err = c.termContext(err)
return commandTag, err
}
func (c *Conn) initContext(ctx context.Context) error {
if c.ctxInProgress {
return errors.New("ctx already in progress")
}
if ctx.Done() == nil {
return nil
}
select { select {
case <-ctx.Done(): case <-ctx.Done():
return "", ctx.Err() return ctx.Err()
default: default:
} }
c.ctxInProgress = true
go c.contextHandler(ctx) go c.contextHandler(ctx)
commandTag, err = c.Exec(sql, arguments...) return nil
}
func (c *Conn) termContext(opErr error) error {
if !c.ctxInProgress {
return opErr
}
var err error
select { select {
case <-c.closedChan: case err = <-c.closedChan:
return "", ctx.Err() if opErr == nil {
err = nil
}
case c.doneChan <- struct{}{}: case c.doneChan <- struct{}{}:
return commandTag, err err = opErr
} }
c.ctxInProgress = false
return err
} }
func (c *Conn) contextHandler(ctx context.Context) { func (c *Conn) contextHandler(ctx context.Context) {
@ -1394,7 +1421,7 @@ func (c *Conn) contextHandler(ctx context.Context) {
case <-ctx.Done(): case <-ctx.Done():
c.cancelQuery() c.cancelQuery()
c.Close() c.Close()
c.closedChan <- struct{}{} c.closedChan <- ctx.Err()
case <-c.doneChan: case <-c.doneChan:
} }
} }

View File

@ -182,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) {
// Release gives up use of a connection. // Release gives up use of a connection.
func (p *ConnPool) Release(conn *Conn) { func (p *ConnPool) Release(conn *Conn) {
if conn.ctxInProgress {
panic("should never release when context is in progress")
}
if conn.TxStatus != 'I' { if conn.TxStatus != 'I' {
conn.Exec("rollback") conn.Exec("rollback")
} }

View File

@ -50,8 +50,6 @@ type Rows struct {
afterClose func(*Rows) afterClose func(*Rows)
unlockConn bool unlockConn bool
closed bool closed bool
ctx context.Context
} }
func (rows *Rows) FieldDescriptions() []FieldDescription { func (rows *Rows) FieldDescriptions() []FieldDescription {
@ -84,6 +82,9 @@ func (rows *Rows) close() {
} }
} }
// TODO - consider inlining in Close(). This method calling rows.close is a
// foot-gun waiting to happen if anyone puts anything between the call to this
// and rows.close.
func (rows *Rows) readUntilReadyForQuery() { func (rows *Rows) readUntilReadyForQuery() {
for { for {
t, r, err := rows.conn.rxMsg() t, r, err := rows.conn.rxMsg()
@ -122,16 +123,8 @@ func (rows *Rows) Close() {
if rows.closed { if rows.closed {
return return
} }
rows.err = rows.conn.termContext(rows.err)
rows.readUntilReadyForQuery() rows.readUntilReadyForQuery()
if rows.ctx != nil {
select {
case <-rows.conn.closedChan:
rows.err = rows.ctx.Err()
case rows.conn.doneChan <- struct{}{}:
}
}
rows.close() rows.close()
} }
@ -506,20 +499,16 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row {
} }
func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) {
go c.contextHandler(ctx) err := c.initContext(ctx)
rows, err := c.Query(sql, args...)
if err != nil { if err != nil {
select { return nil, err
case <-c.closedChan:
return rows, ctx.Err()
case c.doneChan <- struct{}{}:
return rows, err
}
} }
rows.ctx = ctx rows, err := c.Query(sql, args...)
if err != nil {
err = c.termContext(err)
return nil, err
}
return rows, nil return rows, nil
} }