Initial base.Connect extraction

This commit is contained in:
Jack Christensen 2018-11-12 18:05:26 -06:00
parent 06fb816b71
commit 65e69c5580
4 changed files with 270 additions and 149 deletions

View File

@ -1,11 +1,106 @@
package base
import (
"crypto/md5"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/pgproto3"
)
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// DialFunc is a function that can be used to connect to a PostgreSQL server
type DialFunc func(network, addr string) (net.Conn, error)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 // default: 5432
Database string
User string // default: OS user name
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
Dial DialFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
}
func (cc *ConnConfig) NetworkAddress() (network, address string) {
// If host is a valid path, then address is unix socket
if _, err := os.Stat(cc.Host); err == nil {
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
} else {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
}
return network, address
}
func (cc *ConnConfig) assignDefaults() error {
if cc.User == "" {
user, err := user.Current()
if err != nil {
return err
}
cc.User = user.Username
}
if cc.Port == 0 {
cc.Port = 5432
}
if cc.Dial == nil {
defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute}
cc.Dial = defaultDialer.Dial
}
return nil
}
// Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type Conn struct {
NetConn net.Conn // the underlying TCP or unix domain socket connection
@ -14,6 +109,145 @@ type Conn struct {
RuntimeParams map[string]string // parameters that have been reported by the server
TxStatus byte
Frontend *pgproto3.Frontend
Config ConnConfig
}
func Connect(cc ConnConfig) (*Conn, error) {
err := cc.assignDefaults()
if err != nil {
return nil, err
}
conn := new(Conn)
conn.Config = cc
conn.NetConn, err = cc.Dial(cc.NetworkAddress())
if err != nil {
return nil, err
}
conn.RuntimeParams = make(map[string]string)
if cc.TLSConfig != nil {
if err := conn.startTLS(cc.TLSConfig); err != nil {
return nil, err
}
}
conn.Frontend, err = pgproto3.NewFrontend(conn.NetConn, conn.NetConn)
if err != nil {
return nil, err
}
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: make(map[string]string),
}
// Copy default run-time params
for k, v := range cc.RuntimeParams {
startupMsg.Parameters[k] = v
}
startupMsg.Parameters["user"] = cc.User
if cc.Database != "" {
startupMsg.Parameters["database"] = cc.Database
}
if _, err := conn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
return nil, err
}
for {
msg, err := conn.ReceiveMessage()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
conn.PID = msg.ProcessID
conn.SecretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = conn.rxAuthenticationX(msg); err != nil {
return nil, err
}
case *pgproto3.ReadyForQuery:
return conn, nil
case *pgproto3.ParameterStatus:
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
return nil, PgError{
Severity: msg.Severity,
Code: msg.Code,
Message: msg.Message,
Detail: msg.Detail,
Hint: msg.Hint,
Position: msg.Position,
InternalPosition: msg.InternalPosition,
InternalQuery: msg.InternalQuery,
Where: msg.Where,
SchemaName: msg.SchemaName,
TableName: msg.TableName,
ColumnName: msg.ColumnName,
DataTypeName: msg.DataTypeName,
ConstraintName: msg.ConstraintName,
File: msg.File,
Line: msg.Line,
Routine: msg.Routine,
}
default:
return nil, errors.New("unexpected message")
}
}
}
func (conn *Conn) startTLS(tlsConfig *tls.Config) (err error) {
err = binary.Write(conn.NetConn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return
}
response := make([]byte, 1)
if _, err = io.ReadFull(conn.NetConn, response); err != nil {
return
}
if response[0] != 'S' {
return ErrTLSRefused
}
conn.NetConn = tls.Client(conn.NetConn, tlsConfig)
return nil
}
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = c.txPasswordMessage(c.Config.Password)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:]))
err = c.txPasswordMessage(digestedPassword)
default:
err = errors.New("Received unknown authentication message")
}
return
}
func (conn *Conn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = conn.NetConn.Write(msg.Encode(nil))
return err
}
func hexMD5(s string) string {
hash := md5.New()
io.WriteString(hash, s)
return hex.EncodeToString(hash.Sum(nil))
}
func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) {

152
conn.go
View File

@ -13,8 +13,6 @@ import (
"net"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"regexp"
"strconv"
@ -92,27 +90,11 @@ type ConnConfig struct {
PreferSimpleProtocol bool
}
func (cc *ConnConfig) networkAddress() (network, address string) {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
}
return network, address
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple
// goroutines.
type Conn struct {
BaseConn base.Conn
BaseConn *base.Conn
wbuf []byte
config ConnConfig // config used when establishing this connection
preparedStatements map[string]*PreparedStatement
@ -196,7 +178,7 @@ var ErrDeadConn = errors.New("conn is dead")
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
var ErrTLSRefused = base.ErrTLSRefused
// ErrConnBusy occurs when the connection is busy (for example, in the middle of
// reading query results) and another action is attempted.
@ -237,41 +219,17 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
}
c.logger = c.config.Logger
if c.config.User == "" {
user, err := user.Current()
if err != nil {
return nil, err
}
c.config.User = user.Username
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
}
}
if c.config.Port == 0 {
c.config.Port = 5432
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
}
}
c.onNotice = config.OnNotice
network, address := c.config.networkAddress()
if c.config.Dial == nil {
d := defaultDialer()
c.config.Dial = d.Dial
}
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Host})
}
err = c.connect(config, network, address, config.TLSConfig)
err = c.connect(config, config.TLSConfig)
if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
}
err = c.connect(config, network, address, config.FallbackTLSConfig)
err = c.connect(config, config.FallbackTLSConfig)
}
if err != nil {
@ -284,9 +242,19 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
return c, nil
}
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
c.BaseConn = base.Conn{}
c.BaseConn.NetConn, err = c.config.Dial(network, address)
func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) {
cc := base.ConnConfig{
Host: config.Host,
Port: config.Port,
Database: config.Database,
User: config.User,
Password: config.Password,
TLSConfig: tlsConfig,
Dial: base.DialFunc(config.Dial),
RuntimeParams: config.RuntimeParams,
}
c.BaseConn, err = base.Connect(cc)
if err != nil {
return err
}
@ -299,7 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
}
}()
c.BaseConn.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement)
c.channels = make(map[string]struct{})
c.cancelQueryCompleted = make(chan struct{})
@ -312,81 +279,20 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.status = connStatusIdle
c.mux.Unlock()
if tlsConfig != nil {
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "starting TLS handshake", nil)
}
if err := c.startTLS(tlsConfig); err != nil {
return err
}
// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _, ok := config.RuntimeParams["replication"]; ok {
return nil
}
c.BaseConn.Frontend, err = pgproto3.NewFrontend(c.BaseConn.NetConn, c.BaseConn.NetConn)
if err != nil {
return err
}
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: make(map[string]string),
}
// Copy default run-time params
for k, v := range config.RuntimeParams {
startupMsg.Parameters[k] = v
}
startupMsg.Parameters["user"] = c.config.User
if c.config.Database != "" {
startupMsg.Parameters["database"] = c.config.Database
}
if _, err := c.BaseConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
return err
}
c.pendingReadyForQueryCount = 1
for {
msg, err := c.rxMsg()
if c.ConnInfo == minimalConnInfo {
err = c.initConnInfo()
if err != nil {
return err
}
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
c.BaseConn.PID = msg.ProcessID
c.BaseConn.SecretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = c.rxAuthenticationX(msg); err != nil {
return err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connection established", nil)
}
// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _, ok := config.RuntimeParams["replication"]; ok {
return nil
}
if c.ConnInfo == minimalConnInfo {
err = c.initConnInfo()
if err != nil {
return err
}
}
return nil
default:
if err = c.processContextFreeMsg(msg); err != nil {
return err
}
}
}
return nil
}
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
@ -1609,7 +1515,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
if data == nil {
data = map[string]interface{}{}
}
if c.BaseConn.PID != 0 {
if c.BaseConn != nil && c.BaseConn.PID != 0 {
data["pid"] = c.BaseConn.PID
}
@ -1641,8 +1547,8 @@ func quoteIdentifier(s string) string {
}
func doCancel(c *Conn) error {
network, address := c.config.networkAddress()
cancelConn, err := c.config.Dial(network, address)
network, address := c.BaseConn.Config.NetworkAddress()
cancelConn, err := c.BaseConn.Config.Dial(network, address)
if err != nil {
return err
}

View File

@ -6,6 +6,7 @@ import (
"reflect"
"time"
"github.com/jackc/pgx/base"
"github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype"
)
@ -78,32 +79,7 @@ func (fd FieldDescription) Type() reflect.Type {
}
}
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
type PgError = base.PgError
// Notice represents a notice response message reported by the PostgreSQL
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.

5
v4.md
View File

@ -18,6 +18,11 @@ Potential Changes:
* Decouple connection pool from connections. Connection pool should be entirely replaceable.
* Decouple various logical layers of PostgreSQL connection such that an advanced user can choose what layer to work at and pgx still handles the lower level details. e.g Normal high level query level, PostgreSQL wire protocol message level, or wire byte level.
* Change prepared statement usage from using name as SQL text to specifically calling prepared statement (more like database/sql).
* Remove stdlib hack for RegisterDriverConfig now that database/sql supports better way
Minor Potential Changes:
* Change PgError error implementation to pointer method
## Changes