From 72b6d32e2f841e6be96c5602c248b2875d345c3c Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Tue, 7 Feb 2017 21:49:58 -0600 Subject: [PATCH] Extracted more context handling --- conn.go | 71 ++++++++++++++++++++++++++++++++++++---------------- conn_pool.go | 4 +++ query.go | 33 ++++++++---------------- 3 files changed, 64 insertions(+), 44 deletions(-) diff --git a/conn.go b/conn.go index 453f1a51..b662ba4c 100644 --- a/conn.go +++ b/conn.go @@ -90,8 +90,9 @@ type Conn struct { causeOfDeath error // context support - doneChan chan struct{} - closedChan chan struct{} + ctxInProgress bool + doneChan chan struct{} + closedChan chan error } // PreparedStatement is a description of a prepared statement @@ -262,7 +263,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl c.alive = true c.lastActivityTime = time.Now() c.doneChan = make(chan struct{}) - c.closedChan = make(chan struct{}) + c.closedChan = make(chan error) if tlsConfig != nil { if c.shouldLog(LogLevelDebug) { @@ -629,22 +630,14 @@ func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } func (c *Conn) PrepareExContext(ctx context.Context, name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: + err = c.initContext(ctx) + if err != nil { + return nil, err } - go c.contextHandler(ctx) - ps, err = c.prepareEx(name, sql, opts) - - select { - case <-c.closedChan: - return nil, ctx.Err() - case c.doneChan <- struct{}{}: - return ps, err - } + err = c.termContext(err) + return ps, err } func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { @@ -1371,22 +1364,56 @@ func (c *Conn) PingContext(ctx context.Context) error { } func (c *Conn) ExecContext(ctx context.Context, sql string, arguments ...interface{}) (commandTag CommandTag, err error) { + err = c.initContext(ctx) + if err != nil { + return "", err + } + + commandTag, err = c.Exec(sql, arguments...) + err = c.termContext(err) + return commandTag, err +} + +func (c *Conn) initContext(ctx context.Context) error { + if c.ctxInProgress { + return errors.New("ctx already in progress") + } + + if ctx.Done() == nil { + return nil + } + select { case <-ctx.Done(): - return "", ctx.Err() + return ctx.Err() default: } + c.ctxInProgress = true + go c.contextHandler(ctx) - commandTag, err = c.Exec(sql, arguments...) + return nil +} + +func (c *Conn) termContext(opErr error) error { + if !c.ctxInProgress { + return opErr + } + + var err error select { - case <-c.closedChan: - return "", ctx.Err() + case err = <-c.closedChan: + if opErr == nil { + err = nil + } case c.doneChan <- struct{}{}: - return commandTag, err + err = opErr } + + c.ctxInProgress = false + return err } func (c *Conn) contextHandler(ctx context.Context) { @@ -1394,7 +1421,7 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-ctx.Done(): c.cancelQuery() c.Close() - c.closedChan <- struct{}{} + c.closedChan <- ctx.Err() case <-c.doneChan: } } diff --git a/conn_pool.go b/conn_pool.go index 50b9d588..2a243a76 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -182,6 +182,10 @@ func (p *ConnPool) acquire(deadline *time.Time) (*Conn, error) { // Release gives up use of a connection. func (p *ConnPool) Release(conn *Conn) { + if conn.ctxInProgress { + panic("should never release when context is in progress") + } + if conn.TxStatus != 'I' { conn.Exec("rollback") } diff --git a/query.go b/query.go index daf1b354..61136092 100644 --- a/query.go +++ b/query.go @@ -50,8 +50,6 @@ type Rows struct { afterClose func(*Rows) unlockConn bool closed bool - - ctx context.Context } func (rows *Rows) FieldDescriptions() []FieldDescription { @@ -84,6 +82,9 @@ func (rows *Rows) close() { } } +// TODO - consider inlining in Close(). This method calling rows.close is a +// foot-gun waiting to happen if anyone puts anything between the call to this +// and rows.close. func (rows *Rows) readUntilReadyForQuery() { for { t, r, err := rows.conn.rxMsg() @@ -122,16 +123,8 @@ func (rows *Rows) Close() { if rows.closed { return } + rows.err = rows.conn.termContext(rows.err) rows.readUntilReadyForQuery() - - if rows.ctx != nil { - select { - case <-rows.conn.closedChan: - rows.err = rows.ctx.Err() - case rows.conn.doneChan <- struct{}{}: - } - } - rows.close() } @@ -506,20 +499,16 @@ func (c *Conn) QueryRow(sql string, args ...interface{}) *Row { } func (c *Conn) QueryContext(ctx context.Context, sql string, args ...interface{}) (*Rows, error) { - go c.contextHandler(ctx) - - rows, err := c.Query(sql, args...) - + err := c.initContext(ctx) if err != nil { - select { - case <-c.closedChan: - return rows, ctx.Err() - case c.doneChan <- struct{}{}: - return rows, err - } + return nil, err } - rows.ctx = ctx + rows, err := c.Query(sql, args...) + if err != nil { + err = c.termContext(err) + return nil, err + } return rows, nil }