From 1ff8024df992303d221795fb400b98853abbd7ef Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Mon, 31 Dec 2018 18:00:08 -0600 Subject: [PATCH] Access underlying net.Conn via method Also remove some dead code. --- batch.go | 2 +- conn.go | 55 ++++++++-------------------------------------- copy_from.go | 10 ++++----- fastpath.go | 2 +- pgconn/pgconn.go | 57 ++++++++++++++++++++++++++---------------------- query.go | 2 +- replication.go | 6 ++--- 7 files changed, 51 insertions(+), 83 deletions(-) diff --git a/batch.go b/batch.go index b7d4c835..75252ef5 100644 --- a/batch.go +++ b/batch.go @@ -133,7 +133,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { b.conn.pendingReadyForQueryCount++ } - n, err := b.conn.pgConn.NetConn.Write(buf) + n, err := b.conn.pgConn.Conn().Write(buf) if err != nil { if fatalWriteErr(n, err) { b.conn.die(err) diff --git a/conn.go b/conn.go index 06100004..4197d9a2 100644 --- a/conn.go +++ b/conn.go @@ -2,9 +2,6 @@ package pgx import ( "context" - "crypto/tls" - "encoding/binary" - "io" "net" "reflect" "strconv" @@ -478,7 +475,7 @@ func (c *Conn) LocalAddr() (net.Addr, error) { if !c.IsAlive() { return nil, errors.New("connection not ready") } - return c.pgConn.NetConn.LocalAddr(), nil + return c.pgConn.Conn().LocalAddr(), nil } // Close closes a connection. It is safe to call Close on a already closed @@ -570,7 +567,7 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared buf = appendDescribe(buf, 'S', name) buf = appendSync(buf) - n, err := c.pgConn.NetConn.Write(buf) + n, err := c.pgConn.Conn().Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -666,7 +663,7 @@ func (c *Conn) deallocateContext(ctx context.Context, name string) (err error) { buf = append(buf, 'H') buf = pgio.AppendInt32(buf, 4) - _, err = c.pgConn.NetConn.Write(buf) + _, err = c.pgConn.Conn().Write(buf) if err != nil { c.die(err) return err @@ -794,7 +791,7 @@ func (c *Conn) sendSimpleQuery(sql string, args ...interface{}) error { if len(args) == 0 { buf := appendQuery(c.wbuf, sql) - _, err := c.pgConn.NetConn.Write(buf) + _, err := c.pgConn.Conn().Write(buf) if err != nil { c.die(err) return err @@ -833,7 +830,7 @@ func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{} buf = appendExecute(buf, "", 0) buf = appendSync(buf) - n, err := c.pgConn.NetConn.Write(buf) + n, err := c.pgConn.Conn().Write(buf) if err != nil { if fatalWriteErr(n, err) { c.die(err) @@ -989,40 +986,6 @@ func (c *Conn) rxNotificationResponse(msg *pgproto3.NotificationResponse) { c.notifications = append(c.notifications, n) } -func (c *Conn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(c.pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) - if err != nil { - return - } - - response := make([]byte, 1) - if _, err = io.ReadFull(c.pgConn.NetConn, response); err != nil { - return - } - - if response[0] != 'S' { - return ErrTLSRefused - } - - c.pgConn.NetConn = tls.Client(c.pgConn.NetConn, tlsConfig) - - return nil -} - -func (c *Conn) txPasswordMessage(password string) (err error) { - buf := c.wbuf - buf = append(buf, 'p') - sp := len(buf) - buf = pgio.AppendInt32(buf, -1) - buf = append(buf, password...) - buf = append(buf, 0) - pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - - _, err = c.pgConn.NetConn.Write(buf) - - return err -} - func (c *Conn) die(err error) { c.mux.Lock() defer c.mux.Unlock() @@ -1033,7 +996,7 @@ func (c *Conn) die(err error) { c.status = connStatusClosed c.causeOfDeath = err - c.pgConn.NetConn.Close() + c.pgConn.Conn().Close() } func (c *Conn) lock() error { @@ -1111,7 +1074,7 @@ func doCancel(c *Conn) error { // is no way to be sure a query was canceled. See // https://www.postgresql.org/docs/current/static/protocol-flow.html#AEN112861 func (c *Conn) cancelQuery() { - if err := c.pgConn.NetConn.SetDeadline(time.Now()); err != nil { + if err := c.pgConn.Conn().SetDeadline(time.Now()); err != nil { c.Close() // Close connection if unable to set deadline return } @@ -1198,7 +1161,7 @@ func (c *Conn) execEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.pgConn.NetConn.Write(buf) + n, err := c.pgConn.Conn().Write(buf) c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { c.die(err) @@ -1336,7 +1299,7 @@ func (c *Conn) waitForPreviousCancelQuery(ctx context.Context) error { c.mux.Unlock() select { case <-completeCh: - if err := c.pgConn.NetConn.SetDeadline(time.Time{}); err != nil { + if err := c.pgConn.Conn().SetDeadline(time.Time{}); err != nil { c.Close() // Close connection if unable to disable deadline return err } diff --git a/copy_from.go b/copy_from.go index 0166031a..0bdfb315 100644 --- a/copy_from.go +++ b/copy_from.go @@ -157,7 +157,7 @@ func (ct *copyFrom) run() (int, error) { sentCount += addedRows pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = ct.conn.pgConn.NetConn.Write(buf) + _, err = ct.conn.pgConn.Conn().Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -181,7 +181,7 @@ func (ct *copyFrom) run() (int, error) { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - _, err = ct.conn.pgConn.NetConn.Write(buf) + _, err = ct.conn.pgConn.Conn().Write(buf) if err != nil { panicked = false ct.conn.die(err) @@ -256,7 +256,7 @@ func (ct *copyFrom) cancelCopyIn() error { buf = append(buf, 0) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err := ct.conn.pgConn.NetConn.Write(buf) + _, err := ct.conn.pgConn.Conn().Write(buf) if err != nil { ct.conn.die(err) return err @@ -304,7 +304,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) - if _, err := c.pgConn.NetConn.Write(buf); err != nil { + if _, err := c.pgConn.Conn().Write(buf); err != nil { return "", err } } @@ -313,7 +313,7 @@ func (c *Conn) CopyFromReader(r io.Reader, sql string) (CommandTag, error) { buf = append(buf, copyDone) buf = pgio.AppendInt32(buf, 4) - if _, err := c.pgConn.NetConn.Write(buf); err != nil { + if _, err := c.pgConn.Conn().Write(buf); err != nil { return "", err } diff --git a/fastpath.go b/fastpath.go index 63f8c3c5..f9764c53 100644 --- a/fastpath.go +++ b/fastpath.go @@ -72,7 +72,7 @@ func (f *fastpath) Call(oid pgtype.OID, args []fpArg) (res []byte, err error) { buf = pgio.AppendInt16(buf, 1) // response format code (binary) pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - if _, err := f.cn.pgConn.NetConn.Write(buf); err != nil { + if _, err := f.cn.pgConn.Conn().Write(buf); err != nil { return nil, err } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index fef113e0..776141f9 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -58,7 +58,7 @@ var ErrTLSRefused = errors.New("server refused TLS connection") // PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. type PgConn struct { - NetConn net.Conn // the underlying TCP or unix domain socket connection + conn net.Conn // the underlying TCP or unix domain socket connection PID uint32 // backend pid SecretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server @@ -132,7 +132,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) - pgConn.NetConn, err = config.DialFunc(ctx, network, address) + pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } @@ -141,12 +141,12 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig if config.TLSConfig != nil { if err := pgConn.startTLS(config.TLSConfig); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } } - pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.NetConn, pgConn.NetConn) + pgConn.Frontend, err = pgproto3.NewFrontend(pgConn.conn, pgConn.conn) if err != nil { return nil, err } @@ -166,8 +166,8 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig startupMsg.Parameters["database"] = config.Database } - if _, err := pgConn.NetConn.Write(startupMsg.Encode(nil)); err != nil { - pgConn.NetConn.Close() + if _, err := pgConn.conn.Write(startupMsg.Encode(nil)); err != nil { + pgConn.conn.Close() return nil, err } @@ -183,14 +183,14 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig pgConn.SecretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, err } case *pgproto3.ReadyForQuery: if config.AfterConnectFunc != nil { err := config.AfterConnectFunc(ctx, pgConn) if err != nil { - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, fmt.Errorf("AfterConnectFunc: %v", err) } } @@ -198,7 +198,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, &PgError{ Severity: msg.Severity, Code: msg.Code, @@ -219,20 +219,20 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig Routine: msg.Routine, } default: - pgConn.NetConn.Close() + pgConn.conn.Close() return nil, errors.New("unexpected message") } } } func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { - err = binary.Write(pgConn.NetConn, binary.BigEndian, []int32{8, 80877103}) + err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) - if _, err = io.ReadFull(pgConn.NetConn, response); err != nil { + if _, err = io.ReadFull(pgConn.conn, response); err != nil { return } @@ -240,7 +240,7 @@ func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { return ErrTLSRefused } - pgConn.NetConn = tls.Client(pgConn.NetConn, tlsConfig) + pgConn.conn = tls.Client(pgConn.conn, tlsConfig) return nil } @@ -262,7 +262,7 @@ func (c *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { func (pgConn *PgConn) txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} - _, err = pgConn.NetConn.Write(msg.Encode(nil)) + _, err = pgConn.conn.Write(msg.Encode(nil)) return err } @@ -299,6 +299,11 @@ func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { return msg, nil } +// Conn returns the underlying net.Conn. +func (pgConn *PgConn) Conn() net.Conn { + return pgConn.conn +} + // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. @@ -308,22 +313,22 @@ func (pgConn *PgConn) Close(ctx context.Context) error { } pgConn.closed = true - defer pgConn.NetConn.Close() + defer pgConn.conn.Close() - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() - _, err := pgConn.NetConn.Write([]byte{'X', 0, 0, 0, 4}) + _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { return preferContextOverNetTimeoutError(ctx, err) } - _, err = pgConn.NetConn.Read(make([]byte, 1)) + _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { return preferContextOverNetTimeoutError(ctx, err) } - return pgConn.NetConn.Close() + return pgConn.conn.Close() } // ParameterStatus returns the value of a parameter reported by the server (e.g. @@ -380,7 +385,7 @@ type PgResultReader struct { // consumed it returns nil. If an error occurs it will be reported on the // returned PgResultReader. func (pgConn *PgConn) GetResult(ctx context.Context) *PgResultReader { - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) for pgConn.pendingReadyForQueryCount > 0 { msg, err := pgConn.ReceiveMessage() @@ -491,14 +496,14 @@ func (rr *PgResultReader) close() { func (pgConn *PgConn) Flush(ctx context.Context) error { defer pgConn.resetBatch() - cleanup := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanup := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanup() - n, err := pgConn.NetConn.Write(pgConn.batchBuf) + n, err := pgConn.conn.Write(pgConn.batchBuf) if err != nil { if n > 0 { // Close connection because cannot recover from partially sent message. - pgConn.NetConn.Close() + pgConn.conn.Close() pgConn.closed = true } return preferContextOverNetTimeoutError(ctx, err) @@ -563,14 +568,14 @@ func (pgConn *PgConn) RecoverFromTimeout(ctx context.Context) bool { pgConn.resetBatch() // Clear any existing timeout - pgConn.NetConn.SetDeadline(time.Time{}) + pgConn.conn.SetDeadline(time.Time{}) // Try to cancel any in-progress requests for i := 0; i < int(pgConn.pendingReadyForQueryCount); i++ { pgConn.CancelRequest(ctx) } - cleanupContext := contextDoneToConnDeadline(ctx, pgConn.NetConn) + cleanupContext := contextDoneToConnDeadline(ctx, pgConn.conn) defer cleanupContext() for pgConn.pendingReadyForQueryCount > 0 { @@ -683,7 +688,7 @@ func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. - serverAddr := pgConn.NetConn.RemoteAddr() + serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err diff --git a/query.go b/query.go index 0f1152c1..1914b593 100644 --- a/query.go +++ b/query.go @@ -415,7 +415,7 @@ func (c *Conn) QueryEx(ctx context.Context, sql string, options *QueryExOptions, buf = appendSync(buf) - n, err := c.pgConn.NetConn.Write(buf) + n, err := c.pgConn.Conn().Write(buf) c.lastStmtSent = true if err != nil && fatalWriteErr(n, err) { rows.fatal(err) diff --git a/replication.go b/replication.go index 25d21b48..cf4058ab 100644 --- a/replication.go +++ b/replication.go @@ -193,7 +193,7 @@ func (rc *ReplicationConn) SendStandbyStatus(k *StandbyStatus) (err error) { pgio.SetInt32(buf[sp:], int32(len(buf[sp:]))) - _, err = rc.c.pgConn.NetConn.Write(buf) + _, err = rc.c.pgConn.Conn().Write(buf) if err != nil { rc.c.die(err) } @@ -300,7 +300,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl go func() { select { case <-ctx.Done(): - if err := rc.c.pgConn.NetConn.SetDeadline(time.Now()); err != nil { + if err := rc.c.pgConn.Conn().SetDeadline(time.Now()); err != nil { rc.Close() // Close connection if unable to set deadline return } @@ -314,7 +314,7 @@ func (rc *ReplicationConn) WaitForReplicationMessage(ctx context.Context) (*Repl var err error select { case err = <-rc.c.closedChan: - if err := rc.c.pgConn.NetConn.SetDeadline(time.Time{}); err != nil { + if err := rc.c.pgConn.Conn().SetDeadline(time.Time{}); err != nil { rc.Close() // Close connection if unable to disable deadline return nil, err }