context
Jack Christensen 2017-02-11 14:59:16 -06:00
parent f0dfe4fe89
commit e4f9108e82
7 changed files with 82 additions and 76 deletions

87
conn.go
View File

@ -93,6 +93,8 @@ type Conn struct {
status int32 // One of connStatus* constants status int32 // One of connStatus* constants
causeOfDeath error causeOfDeath error
readyForQuery bool // can the connection be used to send a query
// context support // context support
ctxInProgress bool ctxInProgress bool
doneChan chan struct{} doneChan chan struct{}
@ -653,6 +655,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
} }
} }
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
if c.shouldLog(LogLevelError) { if c.shouldLog(LogLevelError) {
defer func() { defer func() {
if err != nil { if err != nil {
@ -692,6 +698,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
c.die(err) c.die(err)
return nil, err return nil, err
} }
c.readyForQuery = false
ps = &PreparedStatement{Name: name, SQL: sql} ps = &PreparedStatement{Name: name, SQL: sql}
@ -706,7 +713,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
} }
switch t { switch t {
case parseComplete:
case parameterDescription: case parameterDescription:
ps.ParameterOids = c.rxParameterDescription(r) ps.ParameterOids = c.rxParameterDescription(r)
@ -720,7 +726,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
ps.FieldDescriptions[i].DataTypeName = t.Name ps.FieldDescriptions[i].DataTypeName = t.Name
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
} }
case noData:
case readyForQuery: case readyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(r)
@ -739,6 +744,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared
// Deallocate released a prepared statement // Deallocate released a prepared statement
func (c *Conn) Deallocate(name string) (err error) { func (c *Conn) Deallocate(name string) (err error) {
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
delete(c.preparedStatements, name) delete(c.preparedStatements, name)
// close // close
@ -809,6 +818,10 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error)
return notification, nil return notification, nil
} }
if err := c.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
stopTime := time.Now().Add(timeout) stopTime := time.Now().Add(timeout)
for { for {
@ -916,6 +929,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) {
} }
func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
if len(args) == 0 { if len(args) == 0 {
wbuf := newWriteBuf(c, 'Q') wbuf := newWriteBuf(c, 'Q')
@ -927,6 +943,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error {
c.die(err) c.die(err)
return err return err
} }
c.readyForQuery = false
return nil return nil
} }
@ -944,6 +961,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments))
} }
if err := c.ensureConnectionReadyForQuery(); err != nil {
return err
}
// bind // bind
wbuf := newWriteBuf(c, 'B') wbuf := newWriteBuf(c, 'B')
wbuf.WriteByte(0) wbuf.WriteByte(0)
@ -991,6 +1012,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}
if err != nil { if err != nil {
c.die(err) c.die(err)
} }
c.readyForQuery = false
return err return err
} }
@ -1040,9 +1062,6 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
case readyForQuery: case readyForQuery:
c.rxReadyForQuery(r) c.rxReadyForQuery(r)
return commandTag, softErr return commandTag, softErr
case rowDescription:
case dataRow:
case bindComplete:
case commandComplete: case commandComplete:
commandTag = CommandTag(r.readCString()) commandTag = CommandTag(r.readCString())
default: default:
@ -1054,25 +1073,36 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag
} }
// Processes messages that are not exclusive to one context such as // Processes messages that are not exclusive to one context such as
// authentication or query response. The response to these messages // authentication or query response. The response to these messages is the same
// is the same regardless of when they occur. // regardless of when they occur. It also ignores messages that are only
// meaningful in a given context. These messages can occur do to a context
// deadline interrupting message processing. For example, an interrupted query
// may have left DataRow messages on the wire.
func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) {
switch t { switch t {
case 'S': case bindComplete:
c.rxParameterStatus(r) case commandComplete:
return nil case dataRow:
case emptyQueryResponse:
case errorResponse: case errorResponse:
return c.rxErrorResponse(r) return c.rxErrorResponse(r)
case noData:
case noticeResponse: case noticeResponse:
return nil
case emptyQueryResponse:
return nil
case notificationResponse: case notificationResponse:
c.rxNotificationResponse(r) c.rxNotificationResponse(r)
return nil case parameterDescription:
case parseComplete:
case readyForQuery:
c.rxReadyForQuery(r)
case rowDescription:
case 'S':
c.rxParameterStatus(r)
default: default:
return fmt.Errorf("Received unknown message type: %c", t) return fmt.Errorf("Received unknown message type: %c", t)
} }
return nil
} }
func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
@ -1082,7 +1112,9 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) {
t, err = c.mr.rxMsg() t, err = c.mr.rxMsg()
if err != nil { if err != nil {
c.die(err) if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) {
c.die(err)
}
} }
c.lastActivityTime = time.Now() c.lastActivityTime = time.Now()
@ -1183,6 +1215,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) {
} }
func (c *Conn) rxReadyForQuery(r *msgReader) { func (c *Conn) rxReadyForQuery(r *msgReader) {
c.readyForQuery = true
c.TxStatus = r.readByte() c.TxStatus = r.readByte()
} }
@ -1428,3 +1461,27 @@ func (c *Conn) contextHandler(ctx context.Context) {
case <-c.doneChan: case <-c.doneChan:
} }
} }
func (c *Conn) ensureConnectionReadyForQuery() error {
for !c.readyForQuery {
t, r, err := c.rxMsg()
if err != nil {
return err
}
switch t {
case errorResponse:
pgErr := c.rxErrorResponse(r)
if pgErr.Severity == "FATAL" {
return pgErr
}
default:
err = c.processContextFreeMsg(t, r)
if err != nil {
return err
}
}
}
return nil
}

View File

@ -872,7 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) {
t.Fatal("Expected context.Canceled err, got %v", err) t.Fatal("Expected context.Canceled err, got %v", err)
} }
ensureConnDeadOnServer(t, conn, *defaultConnConfig) ensureConnValid(t, conn)
} }
func TestPrepare(t *testing.T) { func TestPrepare(t *testing.T) {

View File

@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() {
ct.conn.rxReadyForQuery(r) ct.conn.rxReadyForQuery(r)
close(ct.readerErrChan) close(ct.readerErrChan)
return return
case commandComplete:
case errorResponse: case errorResponse:
ct.readerErrChan <- ct.conn.rxErrorResponse(r) ct.readerErrChan <- ct.conn.rxErrorResponse(r)
default: default:

View File

@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg {
} }
func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) {
if err := f.cn.ensureConnectionReadyForQuery(); err != nil {
return nil, err
}
wbuf := newWriteBuf(f.cn, 'F') // function call wbuf := newWriteBuf(f.cn, 'F') // function call
wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt32(int32(oid)) // function object id
wbuf.WriteInt16(1) // # of argument format codes wbuf.WriteInt16(1) // # of argument format codes

View File

@ -71,19 +71,3 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) {
t.Error("Wrong values returned") t.Error("Wrong values returned")
} }
} }
func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) {
checkConn := mustConnect(t, config)
defer closeConn(t, checkConn)
for i := 0; i < 10; i++ {
var found bool
err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found)
if err == pgx.ErrNoRows {
return
} else if err != nil {
t.Fatalf("Unable to check if conn is dead on server: %v", err)
}
}
t.Fatal("Expected conn to be disconnected from server, but it wasn't")
}

View File

@ -82,41 +82,6 @@ func (rows *Rows) close() {
} }
} }
// TODO - consider inlining in Close(). This method calling rows.close is a
// foot-gun waiting to happen if anyone puts anything between the call to this
// and rows.close.
func (rows *Rows) readUntilReadyForQuery() {
for {
t, r, err := rows.conn.rxMsg()
if err != nil {
rows.close()
return
}
switch t {
case readyForQuery:
rows.conn.rxReadyForQuery(r)
rows.close()
return
case rowDescription:
case dataRow:
case commandComplete:
case bindComplete:
case errorResponse:
err = rows.conn.rxErrorResponse(r)
if rows.err == nil {
rows.err = err
}
default:
err = rows.conn.processContextFreeMsg(t, r)
if err != nil {
rows.close()
return
}
}
}
}
// Close closes the rows, making the connection ready for use again. It is safe // Close closes the rows, making the connection ready for use again. It is safe
// to call Close after rows is already closed. // to call Close after rows is already closed.
func (rows *Rows) Close() { func (rows *Rows) Close() {
@ -124,7 +89,6 @@ func (rows *Rows) Close() {
return return
} }
rows.err = rows.conn.termContext(rows.err) rows.err = rows.conn.termContext(rows.err)
rows.readUntilReadyForQuery()
rows.close() rows.close()
} }
@ -174,10 +138,6 @@ func (rows *Rows) Next() bool {
} }
switch t { switch t {
case readyForQuery:
rows.conn.rxReadyForQuery(r)
rows.close()
return false
case dataRow: case dataRow:
fieldCount := r.readInt16() fieldCount := r.readInt16()
if int(fieldCount) != len(rows.fields) { if int(fieldCount) != len(rows.fields) {
@ -188,7 +148,9 @@ func (rows *Rows) Next() bool {
rows.mr = r rows.mr = r
return true return true
case commandComplete: case commandComplete:
case bindComplete: rows.close()
return false
default: default:
err = rows.conn.processContextFreeMsg(t, r) err = rows.conn.processContextFreeMsg(t, r)
if err != nil { if err != nil {

View File

@ -1513,7 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) {
t.Fatal("Expected context.Canceled error, got %v", rows.Err()) t.Fatal("Expected context.Canceled error, got %v", rows.Err())
} }
ensureConnDeadOnServer(t, conn, *defaultConnConfig) ensureConnValid(t, conn)
} }
func TestQueryRowContextSuccess(t *testing.T) { func TestQueryRowContextSuccess(t *testing.T) {
@ -1573,5 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) {
t.Fatal("Expected context.Canceled error, got %v", err) t.Fatal("Expected context.Canceled error, got %v", err)
} }
ensureConnDeadOnServer(t, conn, *defaultConnConfig) ensureConnValid(t, conn)
} }