diff --git a/conn.go b/conn.go index 1f9cdad5..14bd7671 100644 --- a/conn.go +++ b/conn.go @@ -211,27 +211,28 @@ func Connect(config ConnConfig) (c *Conn, err error) { for { var t byte var r *MessageReader - if t, r, err = c.rxMsg(); err == nil { - switch t { - case backendKeyData: - c.rxBackendKeyData(r) - case authenticationX: - if err = c.rxAuthenticationX(r); err != nil { - return nil, err - } - case readyForQuery: - c.rxReadyForQuery(r) - c.logger = c.logger.New("pid", c.Pid) - c.logger.Info("Connection established") - return c, nil - default: - if err = c.processContextFreeMsg(t, r); err != nil { - return nil, err - } - } - } else { + t, r, err = c.rxMsg() + if err != nil { return nil, err } + + switch t { + case backendKeyData: + c.rxBackendKeyData(r) + case authenticationX: + if err = c.rxAuthenticationX(r); err != nil { + return nil, err + } + case readyForQuery: + c.rxReadyForQuery(r) + c.logger = c.logger.New("pid", c.Pid) + c.logger.Info("Connection established") + return c, nil + default: + if err = c.processContextFreeMsg(t, r); err != nil { + return nil, err + } + } } } @@ -305,31 +306,36 @@ func (c *Conn) selectFunc(sql string, onDataRow func(*DataRowReader) error, argu return } + var softErr error + for { - if t, r, rxErr := c.rxMsg(); rxErr == nil { - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return - case rowDescription: - fields = c.rxRowDescription(r) - case dataRow: - if err == nil { - var drr *DataRowReader - drr, err = newDataRowReader(r, fields) - if err == nil { - err = onDataRow(drr) - } - } - case commandComplete: - case bindComplete: - default: - if e := c.processContextFreeMsg(t, r); e != nil && err == nil { - err = e + var t byte + var r *MessageReader + t, r, err = c.rxMsg() + if err != nil { + return err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return softErr + case rowDescription: + fields = c.rxRowDescription(r) + case dataRow: + if softErr == nil { + var drr *DataRowReader + drr, softErr = newDataRowReader(r, fields) + if softErr == nil { + softErr = onDataRow(drr) } } - } else { - return rxErr + case commandComplete: + case bindComplete: + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } } } } @@ -456,44 +462,50 @@ func (c *Conn) SelectValueTo(w io.Writer, sql string, arguments ...interface{}) } var numRowsFound int64 + var softErr error for { - if t, bodySize, rxErr := c.rxMsgHeader(); rxErr == nil { - if t == dataRow { - numRowsFound++ + var t byte + var bodySize int32 - if numRowsFound > 1 { - err = NotSingleRowError{RowCount: numRowsFound} - } + t, bodySize, err = c.rxMsgHeader() + if err != nil { + return err + } - if err != nil { - c.rxMsgBody(bodySize) // Read and discard rest of message - continue - } + if t == dataRow { + numRowsFound++ - err = c.rxDataRowValueTo(w, bodySize) - } else { - var body *bytes.Buffer - if body, rxErr = c.rxMsgBody(bodySize); rxErr == nil { - r := newMessageReader(body) - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return - case rowDescription: - case commandComplete: - case bindComplete: - default: - if e := c.processContextFreeMsg(t, r); e != nil && err == nil { - err = e - } - } - } else { - return rxErr + if numRowsFound > 1 { + softErr = NotSingleRowError{RowCount: numRowsFound} + } + + if softErr != nil { + c.rxMsgBody(bodySize) // Read and discard rest of message + continue + } + + softErr = c.rxDataRowValueTo(w, bodySize) + } else { + var body *bytes.Buffer + body, err = c.rxMsgBody(bodySize) + if err != nil { + return err + } + + r := newMessageReader(body) + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return softErr + case rowDescription: + case commandComplete: + case bindComplete: + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e } } - } else { - return rxErr } } } @@ -610,32 +622,37 @@ func (c *Conn) Prepare(name, sql string) (err error) { ps := preparedStatement{Name: name} + var softErr error + for { - if t, r, rxErr := c.rxMsg(); rxErr == nil { - switch t { - case parseComplete: - case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) - case rowDescription: - ps.FieldDescriptions = c.rxRowDescription(r) - for i := range ps.FieldDescriptions { - oid := ps.FieldDescriptions[i].DataType - if ValueTranscoders[oid] != nil && ValueTranscoders[oid].DecodeBinary != nil { - ps.FieldDescriptions[i].FormatCode = 1 - } - } - case noData: - case readyForQuery: - c.rxReadyForQuery(r) - c.preparedStatements[name] = &ps - return - default: - if e := c.processContextFreeMsg(t, r); e != nil && err == nil { - err = e + var t byte + var r *MessageReader + t, r, err := c.rxMsg() + if err != nil { + return err + } + + switch t { + case parseComplete: + case parameterDescription: + ps.ParameterOids = c.rxParameterDescription(r) + case rowDescription: + ps.FieldDescriptions = c.rxRowDescription(r) + for i := range ps.FieldDescriptions { + oid := ps.FieldDescriptions[i].DataType + if ValueTranscoders[oid] != nil && ValueTranscoders[oid].DecodeBinary != nil { + ps.FieldDescriptions[i].FormatCode = 1 } } - } else { - return rxErr + case noData: + case readyForQuery: + c.rxReadyForQuery(r) + c.preparedStatements[name] = &ps + return softErr + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e + } } } } @@ -844,24 +861,29 @@ func (c *Conn) Execute(sql string, arguments ...interface{}) (commandTag Command return } + var softErr error + for { - if t, r, rxErr := c.rxMsg(); rxErr == nil { - switch t { - case readyForQuery: - c.rxReadyForQuery(r) - return - case rowDescription: - case dataRow: - case bindComplete: - case commandComplete: - commandTag = CommandTag(r.ReadCString()) - default: - if e := c.processContextFreeMsg(t, r); e != nil && err == nil { - err = e - } + var t byte + var r *MessageReader + t, r, err = c.rxMsg() + if err != nil { + return commandTag, err + } + + switch t { + case readyForQuery: + c.rxReadyForQuery(r) + return commandTag, softErr + case rowDescription: + case dataRow: + case bindComplete: + case commandComplete: + commandTag = CommandTag(r.ReadCString()) + default: + if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { + softErr = e } - } else { - return "", rxErr } } }