From b63370e5d5cb66a17c78871e3f63c14457894748 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 28 Dec 2018 12:16:50 -0600 Subject: [PATCH] Rename base.Conn to base.PgConn - pgx.Conn embeds base.PgConn privately - Add pgx.Conn.ParameterStatus --- base/conn.go | 52 ++++++++++++++++++++--------------------- batch.go | 2 +- conn.go | 62 ++++++++++++++++++++++++++----------------------- conn_pool.go | 2 +- conn_test.go | 2 +- copy_from.go | 10 ++++---- fastpath.go | 2 +- private_test.go | 2 +- query.go | 6 ++--- replication.go | 6 ++--- tx.go | 2 +- 11 files changed, 76 insertions(+), 72 deletions(-) diff --git a/base/conn.go b/base/conn.go index 26dd3f71..99278839 100644 --- a/base/conn.go +++ b/base/conn.go @@ -101,8 +101,8 @@ func (cc *ConnConfig) assignDefaults() error { return nil } -// Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. -type Conn struct { +// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. +type PgConn struct { NetConn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server @@ -113,29 +113,29 @@ type Conn struct { Config ConnConfig } -func Connect(cc ConnConfig) (*Conn, error) { +func Connect(cc ConnConfig) (*PgConn, error) { err := cc.assignDefaults() if err != nil { return nil, err } - conn := new(Conn) - conn.Config = cc + pgConn := new(PgConn) + pgConn.Config = cc - conn.NetConn, err = cc.Dial(cc.NetworkAddress()) + pgConn.NetConn, err = cc.Dial(cc.NetworkAddress()) if err != nil { return nil, err } - conn.RuntimeParams = make(map[string]string) + pgConn.RuntimeParams = make(map[string]string) if cc.TLSConfig != nil { - if err := conn.startTLS(cc.TLSConfig); err != nil { + if err := pgConn.startTLS(cc.TLSConfig); err != nil { return nil, err } } - conn.Frontend, err = pgproto3.NewFrontend(conn.NetConn, conn.NetConn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) if err != nil { return nil, err } @@ -155,26 +155,26 @@ func Connect(cc ConnConfig) (*Conn, error) { startupMsg.Parameters["database"] = cc.Database } - if _, err := conn.NetConn.Write(startupMsg.Encode(nil)); err != nil { + if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { return nil, err } for { - msg, err := conn.ReceiveMessage() + msg, err := pgConn.ReceiveMessage() if err != nil { return nil, err } switch msg := msg.(type) { case *pgproto3.BackendKeyData: - conn.PID = msg.ProcessID - conn.SecretKey = msg.SecretKey + pgConn.PID = msg.ProcessID + pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: - if err = conn.rxAuthenticationX(msg); err != nil { + if err = pgConn.rxAuthenticationX(msg); err != nil { return nil, err } case *pgproto3.ReadyForQuery: - return conn, nil + return pgConn, nil case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: @@ -203,14 +203,14 @@ func Connect(cc ConnConfig) (*Conn, error) { } } -func (conn *Conn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(conn.NetConn, binary.BigEndian, []int32{8, 80877103}) +func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { + err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(conn.NetConn, response); err != nil { + if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { return } @@ -218,12 +218,12 @@ func (conn *Conn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - conn.NetConn = tls.Client(conn.NetConn, tlsConfig) + pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) return nil } -func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { +func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: @@ -238,9 +238,9 @@ func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { return } -func (conn *Conn) txPasswordMessage(password string) (err error) { +func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = conn.NetConn.Write(msg.Encode(nil)) + _, err = pgConn.NetConn.Write(msg.Encode(nil)) return err } @@ -250,17 +250,17 @@ func hexMD5(s string) string { return hex.EncodeToString(hash.Sum(nil)) } -func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) { - msg, err := conn.Frontend.Receive() +func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { + msg, err := pgConn.Frontend.Receive() if err != nil { return nil, err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: - conn.TxStatus = msg.TxStatus + pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: - conn.RuntimeParams[msg.Name] = msg.Value + pgConn.RuntimeParams[msg.Name] = msg.Value } return msg, nil diff --git a/batch.go b/batch.go index f7558030..b7d4c835 100644 --- a/batch.go +++ b/batch.go @@ -133,7 +133,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { b.conn.pendingReadyForQueryCount++ } - n, err := b.conn.BaseConn.NetConn.Write(buf) + n, err := b.conn.pgConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { b.conn.die(err) diff --git a/conn.go b/conn.go index 447af243..a202fa4b 100644 --- a/conn.go +++ b/conn.go @@ -94,7 +94,7 @@ type ConnConfig struct { // Use ConnPool to manage access to multiple database connections from multiple // goroutines. type Conn struct { - BaseConn *base.Conn + pgConn *base.PgConn wbuf []byte config ConnConfig // config used when establishing this connection preparedStatements map[string]*PreparedStatement @@ -255,13 +255,13 @@ func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) { RuntimeParams: config.RuntimeParams, } - c.BaseConn, err = base.Connect(cc) + c.pgConn, err = base.Connect(cc) if err != nil { return err } defer func() { if c != nil && err != nil { - c.BaseConn.NetConn.Close() + c.pgConn.NetConn.Close() c.mux.Lock() c.status = connStatusClosed c.mux.Unlock() @@ -511,7 +511,7 @@ func (c *Conn) crateDBTypesQuery(err error) (*pgtype.ConnInfo, error) { // PID returns the backend PID for this connection. func (c *Conn) PID() uint32 { - return c.BaseConn.PID + return c.pgConn.PID } // LocalAddr returns the underlying connection's local address @@ -519,7 +519,7 @@ func (c *Conn) LocalAddr() (net.Addr, error) { if !c.IsAlive() { return nil, errors.New("connection not ready") } - return c.BaseConn.NetConn.LocalAddr(), nil + return c.pgConn.NetConn.LocalAddr(), nil } // Close closes a connection. It is safe to call Close on a already closed @@ -534,32 +534,32 @@ func (c *Conn) Close() (err error) { c.status = connStatusClosed defer func() { - c.BaseConn.NetConn.Close() + c.pgConn.NetConn.Close() c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(LogLevelInfo, "closed connection", nil) } }() - err = c.BaseConn.NetConn.SetDeadline(time.Time{}) + err = c.pgConn.NetConn.SetDeadline(time.Time{}) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to clear deadlines to send close message", map[string]interface{}{"err": err}) return err } - _, err = c.BaseConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + _, err = c.pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to send terminate message", map[string]interface{}{"err": err}) return err } - err = c.BaseConn.NetConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + err = c.pgConn.NetConn.SetReadDeadline(time.Now().Add(5 * time.Second)) if err != nil && c.shouldLog(LogLevelWarn) { c.log(LogLevelWarn, "failed to set read deadline to finish closing", map[string]interface{}{"err": err}) return err } - _, err = c.BaseConn.NetConn.Read(make([]byte, 1)) + _, err = c.pgConn.NetConn.Read(make([]byte, 1)) if err != io.EOF { return err } @@ -933,6 +933,10 @@ func configTLS(args configTLSArgs, cc *ConnConfig) error { return nil } +func (c *Conn) ParameterStatus(key string) string { + return c.pgConn.RuntimeParams[key] +} + // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. // @@ -997,7 +1001,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = appendDescribe(buf, 'S', name) buf = appendSync(buf) - n, err := c.BaseConn.NetConn.Write(buf) + n, err := c.pgConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -1093,7 +1097,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { buf = append(buf, 'H') buf = pgio.AppendInt32(buf, 4) - _, err = c.BaseConn.NetConn.Write(buf) + _, err = c.pgConn.NetConn.Write(buf) if err != nil { c.die(err) return err @@ -1221,7 +1225,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { if len(args) == 0 { buf := appendQuery(c.wbuf, sql) - _, err := c.BaseConn.NetConn.Write(buf) + _, err := c.pgConn.NetConn.Write(buf) if err != nil { c.die(err) return err @@ -1260,7 +1264,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = appendExecute(buf, "", 0) buf = appendSync(buf) - n, err := c.BaseConn.NetConn.Write(buf) + n, err := c.pgConn.NetConn.Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -1315,7 +1319,7 @@ func (c *Conn) rxMsg() (pgproto3.BackendMessage, error) { return nil, ErrDeadConn } - msg, err := c.BaseConn.ReceiveMessage() + msg, err := c.pgConn.ReceiveMessage() if err != nil { if netErr, ok := err.(net.Error); !(ok && netErr.Timeout()) { c.die(err) @@ -1438,13 +1442,13 @@ func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { } func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(c.BaseConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(c.pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(c.BaseConn.NetConn, response); err != nil { + if _, err = io.ReadFull(c.pgConn.NetConn, response); err != nil { return } @@ -1452,7 +1456,7 @@ func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - c.BaseConn.NetConn = tls.Client(c.BaseConn.NetConn, tlsConfig) + c.pgConn.NetConn = tls.Client(c.pgConn.NetConn, tlsConfig) return nil } @@ -1466,7 +1470,7 @@ func (c *Conn) txPasswordMessage(password string) (err error) { buf = append(buf, 0) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = c.BaseConn.NetConn.Write(buf) + _, err = c.pgConn.NetConn.Write(buf) return err } @@ -1481,7 +1485,7 @@ func (c *Conn) die(err error) { c.status = connStatusClosed c.causeOfDeath = err - c.BaseConn.NetConn.Close() + c.pgConn.NetConn.Close() } func (c *Conn) lock() error { @@ -1516,8 +1520,8 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) { if data == nil { data = map[string]interface{}{} } - if c.BaseConn != nil && c.BaseConn.PID != 0 { - data["pid"] = c.BaseConn.PID + if c.pgConn != nil && c.pgConn.PID != 0 { + data["pid"] = c.pgConn.PID } c.logger.Log(lvl, msg, data) @@ -1548,8 +1552,8 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.BaseConn.Config.NetworkAddress() - cancelConn, err := c.BaseConn.Config.Dial(network, address) + network, address := c.pgConn.Config.NetworkAddress() + cancelConn, err := c.pgConn.Config.Dial(network, address) if err != nil { return err } @@ -1565,8 +1569,8 @@ func doCancel(c *Conn) error { buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) - binary.BigEndian.PutUint32(buf[8:12], uint32(c.BaseConn.PID)) - binary.BigEndian.PutUint32(buf[12:16], uint32(c.BaseConn.SecretKey)) + binary.BigEndian.PutUint32(buf[8:12], uint32(c.pgConn.PID)) + binary.BigEndian.PutUint32(buf[12:16], uint32(c.pgConn.SecretKey)) _, err = cancelConn.Write(buf) if err != nil { return err @@ -1586,7 +1590,7 @@ func doCancel(c *Conn) error { // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 func (c *Conn) cancelQuery() { - if err := c.BaseConn.NetConn.SetDeadline(time.Now()); err != nil { + if err := c.pgConn.NetConn.SetDeadline(time.Now()); err != nil { c.Close() // Close connection if unable to set deadline return } @@ -1673,7 +1677,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.BaseConn.NetConn.Write(buf) + n, err := c.pgConn.NetConn.Write(buf) c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { c.die(err) @@ -1811,7 +1815,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { c.mux.Unlock() select { case <-completeCh: - if err := c.BaseConn.NetConn.SetDeadline(time.Time{}); err != nil { + if err := c.pgConn.NetConn.SetDeadline(time.Time{}); err != nil { c.Close() // Close connection if unable to disable deadline return err } diff --git a/conn_pool.go b/conn_pool.go index 068a6886..b9ae1d07 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -193,7 +193,7 @@ func (p *ConnPool) Release(conn *Conn) { panic("should never release when context is in progress") } - if conn.BaseConn.TxStatus != 'I' { + if conn.pgConn.TxStatus != 'I' { conn.Exec("rollback") } diff --git a/conn_test.go b/conn_test.go index a9aaae21..db6cbc10 100644 --- a/conn_test.go +++ b/conn_test.go @@ -52,7 +52,7 @@ func TestConnect(t *testing.T) { t.Fatalf("Unable to establish connection: %v", err) } - if _, present := conn.BaseConn.RuntimeParams["server_version"]; !present { + if conn.ParameterStatus("server_version") == "" { t.Error("Runtime parameters not stored") } diff --git a/copy_from.go b/copy_from.go index 1e9a3c77..0166031a 100644 --- a/copy_from.go +++ b/copy_from.go @@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { sentCount += addedRows pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = ct.conn.BaseConn.NetConn.Write(buf) + _, err = ct.conn.pgConn.NetConn.Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -181,7 +181,7 @@ func (ct *copyFrom) run() (int, error) { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - _, err = ct.conn.BaseConn.NetConn.Write(buf) + _, err = ct.conn.pgConn.NetConn.Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -256,7 +256,7 @@ func (ct *copyFrom) cancelCopyIn() error { buf = append(buf, 0) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err := ct.conn.BaseConn.NetConn.Write(buf) + _, err := ct.conn.pgConn.NetConn.Write(buf) if err != nil { ct.conn.die(err) return err @@ -304,7 +304,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - if _, err := c.BaseConn.NetConn.Write(buf); err != nil { + if _, err := c.pgConn.NetConn.Write(buf); err != nil { return "", err } } @@ -313,7 +313,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - if _, err := c.BaseConn.NetConn.Write(buf); err != nil { + if _, err := c.pgConn.NetConn.Write(buf); err != nil { return "", err } diff --git a/fastpath.go b/fastpath.go index 1dffe1b0..63f8c3c5 100644 --- a/fastpath.go +++ b/fastpath.go @@ -72,7 +72,7 @@ func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { buf = pgio.AppendInt16(buf, 1) // response format code (binary) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - if _, err := f.cn.BaseConn.NetConn.Write(buf); err != nil { + if _, err := f.cn.pgConn.NetConn.Write(buf); err != nil { return nil, err } diff --git a/private_test.go b/private_test.go index eef82ae8..dd76b43e 100644 --- a/private_test.go +++ b/private_test.go @@ -3,5 +3,5 @@ package pgx // This file contains methods that expose internal pgx state to tests. func (c *Conn) TxStatus() byte { - return c.BaseConn.TxStatus + return c.pgConn.TxStatus } diff --git a/query.go b/query.go index d807e22c..b3bb56e3 100644 --- a/query.go +++ b/query.go @@ -415,7 +415,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.BaseConn.NetConn.Write(buf) + n, err := c.pgConn.NetConn.Write(buf) c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { rows.fatal(err) @@ -519,11 +519,11 @@ func (c *Conn) readUntilRowDescription() ([]FieldDescription, error) { } func (c *Conn) sanitizeAndSendSimpleQuery(sql string, args ...interface{}) (err error) { - if c.BaseConn.RuntimeParams["standard_conforming_strings"] != "on" { + if c.pgConn.RuntimeParams["standard_conforming_strings"] != "on" { return errors.New("simple protocol queries must be run with standard_conforming_strings=on") } - if c.BaseConn.RuntimeParams["client_encoding"] != "UTF8" { + if c.pgConn.RuntimeParams["client_encoding"] != "UTF8" { return errors.New("simple protocol queries must be run with client_encoding=UTF8") } diff --git a/replication.go b/replication.go index 71d2847d..782051fc 100644 --- a/replication.go +++ b/replication.go @@ -193,7 +193,7 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = rc.c.BaseConn.NetConn.Write(buf) + _, err = rc.c.pgConn.NetConn.Write(buf) if err != nil { rc.c.die(err) } @@ -300,7 +300,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl go func() { select { case <-ctx.Done(): - if err := rc.c.BaseConn.NetConn.SetDeadline(time.Now()); err != nil { + if err := rc.c.pgConn.NetConn.SetDeadline(time.Now()); err != nil { rc.Close() // Close connection if unable to set deadline return } @@ -314,7 +314,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl var err error select { case err = <-rc.c.closedChan: - if err := rc.c.BaseConn.NetConn.SetDeadline(time.Time{}); err != nil { + if err := rc.c.pgConn.NetConn.SetDeadline(time.Time{}); err != nil { rc.Close() // Close connection if unable to disable deadline return nil, err } diff --git a/tx.go b/tx.go index 123f82b9..30b0ede3 100644 --- a/tx.go +++ b/tx.go @@ -260,7 +260,7 @@ func (tx *Tx) CopyToWriter(w io.Writer, sql string, args ...interface{}) (comman // Status returns the status of the transaction from the set of // pgx.TxStatus* constants. func (tx *Tx) Status() int8 { - if tx.status == TxStatusInProgress && tx.conn.BaseConn.TxStatus == 'E' { + if tx.status == TxStatusInProgress && tx.conn.pgConn.TxStatus == 'E' { return TxStatusInFailure } return tx.status