diff --git a/base/conn.go b/base/conn.go new file mode 100644 index 00000000..7434f720 --- /dev/null +++ b/base/conn.go @@ -0,0 +1,33 @@ +package base + +import ( + "net" + + "github.com/jackc/pgx/pgproto3" +) + +// Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type Conn struct { + NetConn net.Conn // the underlying TCP or unix domain socket connection + PID uint32 // backend pid + SecretKey uint32 // key to use to send a cancel query message to the server + RuntimeParams map[string]string // parameters that have been reported by the server + TxStatus byte + Frontend *pgproto3.Frontend +} + +func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) { + msg, err := conn.Frontend.Receive() + if err != nil { + return nil, err + } + + switch msg := msg.(type) { + case *pgproto3.ReadyForQuery: + conn.TxStatus = msg.TxStatus + case *pgproto3.ParameterStatus: + conn.RuntimeParams[msg.Name] = msg.Value + } + + return msg, nil +} diff --git a/batch.go b/batch.go index 0d7f14cc..f7558030 100644 --- a/batch.go +++ b/batch.go @@ -133,7 +133,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { b.conn.pendingReadyForQueryCount++ } - n, err := b.conn.conn.Write(buf) + n, err := b.conn.BaseConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { b.conn.die(err) diff --git a/conn.go b/conn.go index 653c63a9..f7c9c85d 100644 --- a/conn.go +++ b/conn.go @@ -24,6 +24,7 @@ import ( "github.com/pkg/errors" + "github.com/jackc/pgx/base" "github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgtype" @@ -111,13 +112,9 @@ func (cc *ConnConfig) networkAddress() (network, address string) { // Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - conn net.Conn // the underlying TCP or unix domain socket connection + BaseConn base.Conn wbuf []byte - pid uint32 // backend pid - secretKey uint32 // key to use to send a cancel query message to the server - RuntimeParams map[string]string // parameters that have been reported by the server - config ConnConfig // config used when establishing this connection - txStatus byte + config ConnConfig // config used when establishing this connection preparedStatements map[string]*PreparedStatement channels map[string]struct{} notifications []*Notification @@ -141,8 +138,6 @@ type Conn struct { closedChan chan error ConnInfo *pgtype.ConnInfo - - frontend *pgproto3.Frontend } // PreparedStatement is a description of a prepared statement @@ -290,20 +285,21 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { - c.conn, err = c.config.Dial(network, address) + c.BaseConn = base.Conn{} + c.BaseConn.NetConn, err = c.config.Dial(network, address) if err != nil { return err } defer func() { if c != nil && err != nil { - c.conn.Close() + c.BaseConn.NetConn.Close() c.mux.Lock() c.status = connStatusClosed c.mux.Unlock() } }() - c.RuntimeParams = make(map[string]string) + c.BaseConn.RuntimeParams = make(map[string]string) c.preparedStatements = make(map[string]*PreparedStatement) c.channels = make(map[string]struct{}) c.cancelQueryCompleted = make(chan struct{}) @@ -325,7 +321,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl } } - c.frontend, err = pgproto3.NewFrontend(c.conn, c.conn) + c.BaseConn.Frontend, err = pgproto3.NewFrontend(c.BaseConn.NetConn, c.BaseConn.NetConn) if err != nil { return err } @@ -345,7 +341,7 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl startupMsg.Parameters["database"] = c.config.Database } - if _, err := c.conn.Write(startupMsg.Encode(nil)); err != nil { + if _, err := c.BaseConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { return err } @@ -359,7 +355,8 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl switch msg := msg.(type) { case *pgproto3.BackendKeyData: - c.rxBackendKeyData(msg) + c.BaseConn.PID = msg.ProcessID + c.BaseConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = c.rxAuthenticationX(msg); err != nil { return err @@ -607,7 +604,7 @@ func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { // PID returns the backend PID for this connection. func (c *Conn) PID() uint32 { - return c.pid + return c.BaseConn.PID } // LocalAddr returns the underlying connection's local address @@ -615,7 +612,7 @@ func (c *Conn) LocalAddr() (net.Addr, error) { if !c.IsAlive() { return nil, errors.New("connection not ready") } - return c.conn.LocalAddr(), nil + return c.BaseConn.NetConn.LocalAddr(), nil } // Close closes a connection. It is safe to call Close on a already closed @@ -630,32 +627,32 @@ func (c *Conn) Close() (err error) { c.status = connStatusClosed defer func() { - c.conn.Close() + c.BaseConn.NetConn.Close() c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "closed connection", nil) } }() - err = c.conn.SetDeadline(time.Time{}) + err = c.BaseConn.NetConn.SetDeadline(time.Time{}) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) return err } - _, err = c.conn.Write([]byte{'X', 0, 0, 0, 4}) + _, err = c.BaseConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) return err } - err = c.conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + err = c.BaseConn.NetConn.SetReadDeadline(time.Now().Add(5 * time.Second)) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) return err } - _, err = c.conn.Read(make([]byte, 1)) + _, err = c.BaseConn.NetConn.Read(make([]byte, 1)) if err != io.EOF { return err } @@ -1093,7 +1090,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = appendDescribe(buf, 'S', name) buf = appendSync(buf) - n, err := c.conn.Write(buf) + n, err := c.BaseConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -1189,7 +1186,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { buf = append(buf, 'H') buf = pgio.AppendInt32(buf, 4) - _, err = c.conn.Write(buf) + _, err = c.BaseConn.NetConn.Write(buf) if err != nil { c.die(err) return err @@ -1317,7 +1314,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { if len(args) == 0 { buf := appendQuery(c.wbuf, sql) - _, err := c.conn.Write(buf) + _, err := c.BaseConn.NetConn.Write(buf) if err != nil { c.die(err) return err @@ -1356,7 +1353,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = appendExecute(buf, "", 0) buf = appendSync(buf) - n, err := c.conn.Write(buf) + n, err := c.BaseConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -1401,8 +1398,6 @@ func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { c.rxNotificationResponse(msg) case *pgproto3.ReadyForQuery: c.rxReadyForQuery(msg) - case *pgproto3.ParameterStatus: - c.rxParameterStatus(msg) } return nil @@ -1413,7 +1408,7 @@ func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { return nil, ErrDeadConn } - msg, err := c.frontend.Receive() + msg, err := c.BaseConn.ReceiveMessage() if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { c.die(err) @@ -1421,8 +1416,6 @@ func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { return nil, err } - // fmt.Printf("rxMsg: %#v\n", msg) - return msg, nil } @@ -1447,10 +1440,6 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (c *Conn) rxParameterStatus(msg *pgproto3.ParameterStatus) { - c.RuntimeParams[msg.Name] = msg.Value -} - func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) PgError { err := PgError{ Severity: msg.Severity, @@ -1507,14 +1496,8 @@ func (c *Conn) rxNoticeResponse(msg *pgproto3.NoticeResponse) { c.onNotice(c, notice) } -func (c *Conn) rxBackendKeyData(msg *pgproto3.BackendKeyData) { - c.pid = msg.ProcessID - c.secretKey = msg.SecretKey -} - func (c *Conn) rxReadyForQuery(msg *pgproto3.ReadyForQuery) { c.pendingReadyForQueryCount-- - c.txStatus = msg.TxStatus } func (c *Conn) rxRowDescription(msg *pgproto3.RowDescription) []FieldDescription { @@ -1548,13 +1531,13 @@ func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { } func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(c.conn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(c.BaseConn.NetConn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(c.conn, response); err != nil { + if _, err = io.ReadFull(c.BaseConn.NetConn, response); err != nil { return } @@ -1562,7 +1545,7 @@ func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - c.conn = tls.Client(c.conn, tlsConfig) + c.BaseConn.NetConn = tls.Client(c.BaseConn.NetConn, tlsConfig) return nil } @@ -1576,7 +1559,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) { buf = append(buf, 0) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = c.conn.Write(buf) + _, err = c.BaseConn.NetConn.Write(buf) return err } @@ -1591,7 +1574,7 @@ func (c *Conn) die(err error) { c.status = connStatusClosed c.causeOfDeath = err - c.conn.Close() + c.BaseConn.NetConn.Close() } func (c *Conn) lock() error { @@ -1626,8 +1609,8 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { if data == nil { data = map[string]interface{}{} } - if c.pid != 0 { - data["pid"] = c.pid + if c.BaseConn.PID != 0 { + data["pid"] = c.BaseConn.PID } c.logger.Log(lvl, msg, data) @@ -1675,8 +1658,8 @@ func doCancel(c *Conn) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.pid)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.secretKey)) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.BaseConn.PID)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.BaseConn.SecretKey)) _, err = cancelConn.Write(buf) if err != nil { return err @@ -1696,7 +1679,7 @@ func doCancel(c *Conn) error { // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 func (c *Conn) cancelQuery() { - if err := c.conn.SetDeadline(time.Now()); err != nil { + if err := c.BaseConn.NetConn.SetDeadline(time.Now()); err != nil { c.Close() // Close connection if unable to set deadline return } @@ -1781,7 +1764,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.conn.Write(buf) + n, err := c.BaseConn.NetConn.Write(buf) if err != nil && fatalWriteErr(n, err) { c.die(err) return "", err @@ -1916,7 +1899,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { c.mux.Unlock() select { case <-completeCh: - if err := c.conn.SetDeadline(time.Time{}); err != nil { + if err := c.BaseConn.NetConn.SetDeadline(time.Time{}); err != nil { c.Close() // Close connection if unable to disable deadline return err } diff --git a/conn_pool.go b/conn_pool.go index 270b992f..fda874ba 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -193,7 +193,7 @@ func (p *ConnPool) Release(conn *Conn) { panic("should never release when context is in progress") } - if conn.txStatus != 'I' { + if conn.BaseConn.TxStatus != 'I' { conn.Exec("rollback") } diff --git a/conn_test.go b/conn_test.go index c0419d90..b245af2e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -52,7 +52,7 @@ func TestConnect(t *testing.T) { t.Fatalf("Unable to establish connection: %v", err) } - if _, present := conn.RuntimeParams["server_version"]; !present { + if _, present := conn.BaseConn.RuntimeParams["server_version"]; !present { t.Error("Runtime parameters not stored") } diff --git a/copy_from.go b/copy_from.go index 27e2fc9a..a4d4d91c 100644 --- a/copy_from.go +++ b/copy_from.go @@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { sentCount += addedRows pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = ct.conn.conn.Write(buf) + _, err = ct.conn.BaseConn.NetConn.Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -181,7 +181,7 @@ func (ct *copyFrom) run() (int, error) { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - _, err = ct.conn.conn.Write(buf) + _, err = ct.conn.BaseConn.NetConn.Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -256,7 +256,7 @@ func (ct *copyFrom) cancelCopyIn() error { buf = append(buf, 0) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err := ct.conn.conn.Write(buf) + _, err := ct.conn.BaseConn.NetConn.Write(buf) if err != nil { ct.conn.die(err) return err @@ -304,7 +304,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - if _, err := c.conn.Write(buf); err != nil { + if _, err := c.BaseConn.NetConn.Write(buf); err != nil { return err } } @@ -313,7 +313,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) error { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - if _, err := c.conn.Write(buf); err != nil { + if _, err := c.BaseConn.NetConn.Write(buf); err != nil { return err } diff --git a/fastpath.go b/fastpath.go index f8af6190..1dffe1b0 100644 --- a/fastpath.go +++ b/fastpath.go @@ -72,7 +72,7 @@ func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { buf = pgio.AppendInt16(buf, 1) // response format code (binary) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - if _, err := f.cn.conn.Write(buf); err != nil { + if _, err := f.cn.BaseConn.NetConn.Write(buf); err != nil { return nil, err } diff --git a/private_test.go b/private_test.go index df732a72..eef82ae8 100644 --- a/private_test.go +++ b/private_test.go @@ -3,5 +3,5 @@ package pgx // This file contains methods that expose internal pgx state to tests. func (c *Conn) TxStatus() byte { - return c.txStatus + return c.BaseConn.TxStatus } diff --git a/query.go b/query.go index ef99b1e5..c79540fa 100644 --- a/query.go +++ b/query.go @@ -413,7 +413,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.conn.Write(buf) + n, err := c.BaseConn.NetConn.Write(buf) if err != nil && fatalWriteErr(n, err) { rows.fatal(err) c.die(err) @@ -515,11 +515,11 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { } func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { - if c.RuntimeParams["standard_conforming_strings"] != "on" { + if c.BaseConn.RuntimeParams["standard_conforming_strings"] != "on" { return errors.New("simple protocol queries must be run with standard_conforming_strings=on") } - if c.RuntimeParams["client_encoding"] != "UTF8" { + if c.BaseConn.RuntimeParams["client_encoding"] != "UTF8" { return errors.New("simple protocol queries must be run with client_encoding=UTF8") } diff --git a/replication.go b/replication.go index 452f9d3d..71d2847d 100644 --- a/replication.go +++ b/replication.go @@ -193,7 +193,7 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = rc.c.conn.Write(buf) + _, err = rc.c.BaseConn.NetConn.Write(buf) if err != nil { rc.c.die(err) } @@ -300,7 +300,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl go func() { select { case <-ctx.Done(): - if err := rc.c.conn.SetDeadline(time.Now()); err != nil { + if err := rc.c.BaseConn.NetConn.SetDeadline(time.Now()); err != nil { rc.Close() // Close connection if unable to set deadline return } @@ -314,7 +314,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl var err error select { case err = <-rc.c.closedChan: - if err := rc.c.conn.SetDeadline(time.Time{}); err != nil { + if err := rc.c.BaseConn.NetConn.SetDeadline(time.Time{}); err != nil { rc.Close() // Close connection if unable to disable deadline return nil, err } diff --git a/tx.go b/tx.go index 0fb428fb..611d3f9f 100644 --- a/tx.go +++ b/tx.go @@ -260,7 +260,7 @@ func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) error { // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. func (tx *Tx) Status() int8 { - if tx.status == TxStatusInProgress && tx.conn.txStatus == 'E' { + if tx.status == TxStatusInProgress && tx.conn.BaseConn.TxStatus == 'E' { return TxStatusInFailure } return tx.status diff --git a/v4_changes.md b/v4_changes.md new file mode 100644 index 00000000..f38e2b71 --- /dev/null +++ b/v4_changes.md @@ -0,0 +1,7 @@ +# V4 Changes + +`base.Conn` now contains core PostgreSQL connection functionality. + +## Incompatible Changes + +* `RuntimeParams` removed from `pgx.Conn` and added to `base.Conn`