diff --git a/conn.go b/conn.go index 7243a4d1..f7443719 100644 --- a/conn.go +++ b/conn.go @@ -93,6 +93,8 @@ type Conn struct { status int32 // One of connStatus* constants causeOfDeath error + readyForQuery bool // can the connection be used to send a query + // context support ctxInProgress bool doneChan chan struct{} @@ -653,6 +655,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + if c.shouldLog(LogLevelError) { defer func() { if err != nil { @@ -692,6 +698,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared c.die(err) return nil, err } + c.readyForQuery = false ps = &PreparedStatement{Name: name, SQL: sql} @@ -706,7 +713,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared } switch t { - case parseComplete: case parameterDescription: ps.ParameterOids = c.rxParameterDescription(r) @@ -720,7 +726,6 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared ps.FieldDescriptions[i].DataTypeName = t.Name ps.FieldDescriptions[i].FormatCode = t.DefaultFormat } - case noData: case readyForQuery: c.rxReadyForQuery(r) @@ -739,6 +744,10 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared // Deallocate released a prepared statement func (c *Conn) Deallocate(name string) (err error) { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + delete(c.preparedStatements, name) // close @@ -809,6 +818,10 @@ func (c *Conn) WaitForNotification(timeout time.Duration) (*Notification, error) return notification, nil } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + stopTime := time.Now().Add(timeout) for { @@ -916,6 +929,9 @@ func (c *Conn) sendQuery(sql string, arguments ...interface{}) (err error) { } func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } if len(args) == 0 { wbuf := newWriteBuf(c, 'Q') @@ -927,6 +943,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { c.die(err) return err } + c.readyForQuery = false return nil } @@ -944,6 +961,10 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) } + if err := c.ensureConnectionReadyForQuery(); err != nil { + return err + } + // bind wbuf := newWriteBuf(c, 'B') wbuf.WriteByte(0) @@ -991,6 +1012,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} if err != nil { c.die(err) } + c.readyForQuery = false return err } @@ -1040,9 +1062,6 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag case readyForQuery: c.rxReadyForQuery(r) return commandTag, softErr - case rowDescription: - case dataRow: - case bindComplete: case commandComplete: commandTag = CommandTag(r.readCString()) default: @@ -1054,25 +1073,36 @@ func (c *Conn) Exec(sql string, arguments ...interface{}) (commandTag CommandTag } // Processes messages that are not exclusive to one context such as -// authentication or query response. The response to these messages -// is the same regardless of when they occur. +// authentication or query response. The response to these messages is the same +// regardless of when they occur. It also ignores messages that are only +// meaningful in a given context. These messages can occur do to a context +// deadline interrupting message processing. For example, an interrupted query +// may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(t byte, r *msgReader) (err error) { switch t { - case 'S': - c.rxParameterStatus(r) - return nil + case bindComplete: + case commandComplete: + case dataRow: + case emptyQueryResponse: case errorResponse: return c.rxErrorResponse(r) + case noData: case noticeResponse: - return nil - case emptyQueryResponse: - return nil case notificationResponse: c.rxNotificationResponse(r) - return nil + case parameterDescription: + case parseComplete: + case readyForQuery: + c.rxReadyForQuery(r) + case rowDescription: + case 'S': + c.rxParameterStatus(r) + default: return fmt.Errorf("Received unknown message type: %c", t) } + + return nil } func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { @@ -1082,7 +1112,9 @@ func (c *Conn) rxMsg() (t byte, r *msgReader, err error) { t, err = c.mr.rxMsg() if err != nil { - c.die(err) + if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { + c.die(err) + } } c.lastActivityTime = time.Now() @@ -1183,6 +1215,7 @@ func (c *Conn) rxBackendKeyData(r *msgReader) { } func (c *Conn) rxReadyForQuery(r *msgReader) { + c.readyForQuery = true c.TxStatus = r.readByte() } @@ -1428,3 +1461,27 @@ func (c *Conn) contextHandler(ctx context.Context) { case <-c.doneChan: } } + +func (c *Conn) ensureConnectionReadyForQuery() error { + for !c.readyForQuery { + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case errorResponse: + pgErr := c.rxErrorResponse(r) + if pgErr.Severity == "FATAL" { + return pgErr + } + default: + err = c.processContextFreeMsg(t, r) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/conn_test.go b/conn_test.go index e92c7ca3..ca39b4b4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -872,7 +872,7 @@ func TestExecContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled err, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestPrepare(t *testing.T) { diff --git a/copy_to.go b/copy_to.go index 91292bb0..dd70ada3 100644 --- a/copy_to.go +++ b/copy_to.go @@ -66,7 +66,6 @@ func (ct *copyTo) readUntilReadyForQuery() { ct.conn.rxReadyForQuery(r) close(ct.readerErrChan) return - case commandComplete: case errorResponse: ct.readerErrChan <- ct.conn.rxErrorResponse(r) default: diff --git a/fastpath.go b/fastpath.go index 19b98784..30a9f102 100644 --- a/fastpath.go +++ b/fastpath.go @@ -48,6 +48,10 @@ func fpInt64Arg(n int64) fpArg { } func (f *fastpath) Call(oid Oid, args []fpArg) (res []byte, err error) { + if err := f.cn.ensureConnectionReadyForQuery(); err != nil { + return nil, err + } + wbuf := newWriteBuf(f.cn, 'F') // function call wbuf.WriteInt32(int32(oid)) // function object id wbuf.WriteInt16(1) // # of argument format codes diff --git a/helper_test.go b/helper_test.go index 997ae26f..21f86de5 100644 --- a/helper_test.go +++ b/helper_test.go @@ -71,19 +71,3 @@ func ensureConnValid(t *testing.T, conn *pgx.Conn) { t.Error("Wrong values returned") } } - -func ensureConnDeadOnServer(t *testing.T, conn *pgx.Conn, config pgx.ConnConfig) { - checkConn := mustConnect(t, config) - defer closeConn(t, checkConn) - - for i := 0; i < 10; i++ { - var found bool - err := checkConn.QueryRow("select true from pg_stat_activity where pid=$1", conn.Pid).Scan(&found) - if err == pgx.ErrNoRows { - return - } else if err != nil { - t.Fatalf("Unable to check if conn is dead on server: %v", err) - } - } - t.Fatal("Expected conn to be disconnected from server, but it wasn't") -} diff --git a/query.go b/query.go index 61136092..b6470688 100644 --- a/query.go +++ b/query.go @@ -82,41 +82,6 @@ func (rows *Rows) close() { } } -// TODO - consider inlining in Close(). This method calling rows.close is a -// foot-gun waiting to happen if anyone puts anything between the call to this -// and rows.close. -func (rows *Rows) readUntilReadyForQuery() { - for { - t, r, err := rows.conn.rxMsg() - if err != nil { - rows.close() - return - } - - switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return - case rowDescription: - case dataRow: - case commandComplete: - case bindComplete: - case errorResponse: - err = rows.conn.rxErrorResponse(r) - if rows.err == nil { - rows.err = err - } - default: - err = rows.conn.processContextFreeMsg(t, r) - if err != nil { - rows.close() - return - } - } - } -} - // Close closes the rows, making the connection ready for use again. It is safe // to call Close after rows is already closed. func (rows *Rows) Close() { @@ -124,7 +89,6 @@ func (rows *Rows) Close() { return } rows.err = rows.conn.termContext(rows.err) - rows.readUntilReadyForQuery() rows.close() } @@ -174,10 +138,6 @@ func (rows *Rows) Next() bool { } switch t { - case readyForQuery: - rows.conn.rxReadyForQuery(r) - rows.close() - return false case dataRow: fieldCount := r.readInt16() if int(fieldCount) != len(rows.fields) { @@ -188,7 +148,9 @@ func (rows *Rows) Next() bool { rows.mr = r return true case commandComplete: - case bindComplete: + rows.close() + return false + default: err = rows.conn.processContextFreeMsg(t, r) if err != nil { diff --git a/query_test.go b/query_test.go index 40886f2e..24310ab3 100644 --- a/query_test.go +++ b/query_test.go @@ -1513,7 +1513,7 @@ func TestQueryContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", rows.Err()) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) } func TestQueryRowContextSuccess(t *testing.T) { @@ -1573,5 +1573,5 @@ func TestQueryRowContextCancelationCancelsQuery(t *testing.T) { t.Fatal("Expected context.Canceled error, got %v", err) } - ensureConnDeadOnServer(t, conn, *defaultConnConfig) + ensureConnValid(t, conn) }