Initial base.Connect extraction

This commit is contained in:
Jack Christensen 2018-11-12 18:05:26 -06:00
parent 06fb816b71
commit 65e69c5580
4 changed files with 270 additions and 149 deletions

View File

@ -1,11 +1,106 @@
package base package base
import ( import (
"crypto/md5"
"crypto/tls"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"net" "net"
"os"
"os/user"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/pgproto3" "github.com/jackc/pgx/pgproto3"
) )
// PgError represents an error reported by the PostgreSQL server. See
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// DialFunc is a function that can be used to connect to a PostgreSQL server
type DialFunc func(network, addr string) (net.Conn, error)
// ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection")
type ConnConfig struct {
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
Port uint16 // default: 5432
Database string
User string // default: OS user name
Password string
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
Dial DialFunc
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
}
func (cc *ConnConfig) NetworkAddress() (network, address string) {
// If host is a valid path, then address is unix socket
if _, err := os.Stat(cc.Host); err == nil {
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
} else {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
}
return network, address
}
func (cc *ConnConfig) assignDefaults() error {
if cc.User == "" {
user, err := user.Current()
if err != nil {
return err
}
cc.User = user.Username
}
if cc.Port == 0 {
cc.Port = 5432
}
if cc.Dial == nil {
defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute}
cc.Dial = defaultDialer.Dial
}
return nil
}
// Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage. // Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type Conn struct { type Conn struct {
NetConn net.Conn // the underlying TCP or unix domain socket connection NetConn net.Conn // the underlying TCP or unix domain socket connection
@ -14,6 +109,145 @@ type Conn struct {
RuntimeParams map[string]string // parameters that have been reported by the server RuntimeParams map[string]string // parameters that have been reported by the server
TxStatus byte TxStatus byte
Frontend *pgproto3.Frontend Frontend *pgproto3.Frontend
Config ConnConfig
}
func Connect(cc ConnConfig) (*Conn, error) {
err := cc.assignDefaults()
if err != nil {
return nil, err
}
conn := new(Conn)
conn.Config = cc
conn.NetConn, err = cc.Dial(cc.NetworkAddress())
if err != nil {
return nil, err
}
conn.RuntimeParams = make(map[string]string)
if cc.TLSConfig != nil {
if err := conn.startTLS(cc.TLSConfig); err != nil {
return nil, err
}
}
conn.Frontend, err = pgproto3.NewFrontend(conn.NetConn, conn.NetConn)
if err != nil {
return nil, err
}
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: make(map[string]string),
}
// Copy default run-time params
for k, v := range cc.RuntimeParams {
startupMsg.Parameters[k] = v
}
startupMsg.Parameters["user"] = cc.User
if cc.Database != "" {
startupMsg.Parameters["database"] = cc.Database
}
if _, err := conn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
return nil, err
}
for {
msg, err := conn.ReceiveMessage()
if err != nil {
return nil, err
}
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
conn.PID = msg.ProcessID
conn.SecretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = conn.rxAuthenticationX(msg); err != nil {
return nil, err
}
case *pgproto3.ReadyForQuery:
return conn, nil
case *pgproto3.ParameterStatus:
// handled by ReceiveMessage
case *pgproto3.ErrorResponse:
return nil, PgError{
Severity: msg.Severity,
Code: msg.Code,
Message: msg.Message,
Detail: msg.Detail,
Hint: msg.Hint,
Position: msg.Position,
InternalPosition: msg.InternalPosition,
InternalQuery: msg.InternalQuery,
Where: msg.Where,
SchemaName: msg.SchemaName,
TableName: msg.TableName,
ColumnName: msg.ColumnName,
DataTypeName: msg.DataTypeName,
ConstraintName: msg.ConstraintName,
File: msg.File,
Line: msg.Line,
Routine: msg.Routine,
}
default:
return nil, errors.New("unexpected message")
}
}
}
func (conn *Conn) startTLS(tlsConfig *tls.Config) (err error) {
err = binary.Write(conn.NetConn, binary.BigEndian, []int32{8, 80877103})
if err != nil {
return
}
response := make([]byte, 1)
if _, err = io.ReadFull(conn.NetConn, response); err != nil {
return
}
if response[0] != 'S' {
return ErrTLSRefused
}
conn.NetConn = tls.Client(conn.NetConn, tlsConfig)
return nil
}
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
switch msg.Type {
case pgproto3.AuthTypeOk:
case pgproto3.AuthTypeCleartextPassword:
err = c.txPasswordMessage(c.Config.Password)
case pgproto3.AuthTypeMD5Password:
digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:]))
err = c.txPasswordMessage(digestedPassword)
default:
err = errors.New("Received unknown authentication message")
}
return
}
func (conn *Conn) txPasswordMessage(password string) (err error) {
msg := &pgproto3.PasswordMessage{Password: password}
_, err = conn.NetConn.Write(msg.Encode(nil))
return err
}
func hexMD5(s string) string {
hash := md5.New()
io.WriteString(hash, s)
return hex.EncodeToString(hash.Sum(nil))
} }
func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) { func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) {

152
conn.go
View File

@ -13,8 +13,6 @@ import (
"net" "net"
"net/url" "net/url"
"os" "os"
"os/user"
"path/filepath"
"reflect" "reflect"
"regexp" "regexp"
"strconv" "strconv"
@ -92,27 +90,11 @@ type ConnConfig struct {
PreferSimpleProtocol bool PreferSimpleProtocol bool
} }
func (cc *ConnConfig) networkAddress() (network, address string) {
network = "tcp"
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
// See if host is a valid path, if yes connect with a socket
if _, err := os.Stat(cc.Host); err == nil {
// For backward compatibility accept socket file paths -- but directories are now preferred
network = "unix"
address = cc.Host
if !strings.Contains(address, "/.s.PGSQL.") {
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
}
}
return network, address
}
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
// Use ConnPool to manage access to multiple database connections from multiple // Use ConnPool to manage access to multiple database connections from multiple
// goroutines. // goroutines.
type Conn struct { type Conn struct {
BaseConn base.Conn BaseConn *base.Conn
wbuf []byte wbuf []byte
config ConnConfig // config used when establishing this connection config ConnConfig // config used when establishing this connection
preparedStatements map[string]*PreparedStatement preparedStatements map[string]*PreparedStatement
@ -196,7 +178,7 @@ var ErrDeadConn = errors.New("conn is dead")
// ErrTLSRefused occurs when the connection attempt requires TLS and the // ErrTLSRefused occurs when the connection attempt requires TLS and the
// PostgreSQL server refuses to use TLS // PostgreSQL server refuses to use TLS
var ErrTLSRefused = errors.New("server refused TLS connection") var ErrTLSRefused = base.ErrTLSRefused
// ErrConnBusy occurs when the connection is busy (for example, in the middle of // ErrConnBusy occurs when the connection is busy (for example, in the middle of
// reading query results) and another action is attempted. // reading query results) and another action is attempted.
@ -237,41 +219,17 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
} }
c.logger = c.config.Logger c.logger = c.config.Logger
if c.config.User == "" {
user, err := user.Current()
if err != nil {
return nil, err
}
c.config.User = user.Username
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
}
}
if c.config.Port == 0 {
c.config.Port = 5432
if c.shouldLog(LogLevelDebug) {
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
}
}
c.onNotice = config.OnNotice c.onNotice = config.OnNotice
network, address := c.config.networkAddress()
if c.config.Dial == nil {
d := defaultDialer()
c.config.Dial = d.Dial
}
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Host})
} }
err = c.connect(config, network, address, config.TLSConfig) err = c.connect(config, config.TLSConfig)
if err != nil && config.UseFallbackTLS { if err != nil && config.UseFallbackTLS {
if c.shouldLog(LogLevelInfo) { if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
} }
err = c.connect(config, network, address, config.FallbackTLSConfig) err = c.connect(config, config.FallbackTLSConfig)
} }
if err != nil { if err != nil {
@ -284,9 +242,19 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
return c, nil return c, nil
} }
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) {
c.BaseConn = base.Conn{} cc := base.ConnConfig{
c.BaseConn.NetConn, err = c.config.Dial(network, address) Host: config.Host,
Port: config.Port,
Database: config.Database,
User: config.User,
Password: config.Password,
TLSConfig: tlsConfig,
Dial: base.DialFunc(config.Dial),
RuntimeParams: config.RuntimeParams,
}
c.BaseConn, err = base.Connect(cc)
if err != nil { if err != nil {
return err return err
} }
@ -299,7 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
} }
}() }()
c.BaseConn.RuntimeParams = make(map[string]string)
c.preparedStatements = make(map[string]*PreparedStatement) c.preparedStatements = make(map[string]*PreparedStatement)
c.channels = make(map[string]struct{}) c.channels = make(map[string]struct{})
c.cancelQueryCompleted = make(chan struct{}) c.cancelQueryCompleted = make(chan struct{})
@ -312,81 +279,20 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
c.status = connStatusIdle c.status = connStatusIdle
c.mux.Unlock() c.mux.Unlock()
if tlsConfig != nil { // Replication connections can't execute the queries to
if c.shouldLog(LogLevelDebug) { // populate the c.PgTypes and c.pgsqlAfInet
c.log(LogLevelDebug, "starting TLS handshake", nil) if _, ok := config.RuntimeParams["replication"]; ok {
} return nil
if err := c.startTLS(tlsConfig); err != nil {
return err
}
} }
c.BaseConn.Frontend, err = pgproto3.NewFrontend(c.BaseConn.NetConn, c.BaseConn.NetConn) if c.ConnInfo == minimalConnInfo {
if err != nil { err = c.initConnInfo()
return err
}
startupMsg := pgproto3.StartupMessage{
ProtocolVersion: pgproto3.ProtocolVersionNumber,
Parameters: make(map[string]string),
}
// Copy default run-time params
for k, v := range config.RuntimeParams {
startupMsg.Parameters[k] = v
}
startupMsg.Parameters["user"] = c.config.User
if c.config.Database != "" {
startupMsg.Parameters["database"] = c.config.Database
}
if _, err := c.BaseConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
return err
}
c.pendingReadyForQueryCount = 1
for {
msg, err := c.rxMsg()
if err != nil { if err != nil {
return err return err
} }
switch msg := msg.(type) {
case *pgproto3.BackendKeyData:
c.BaseConn.PID = msg.ProcessID
c.BaseConn.SecretKey = msg.SecretKey
case *pgproto3.Authentication:
if err = c.rxAuthenticationX(msg); err != nil {
return err
}
case *pgproto3.ReadyForQuery:
c.rxReadyForQuery(msg)
if c.shouldLog(LogLevelInfo) {
c.log(LogLevelInfo, "connection established", nil)
}
// Replication connections can't execute the queries to
// populate the c.PgTypes and c.pgsqlAfInet
if _, ok := config.RuntimeParams["replication"]; ok {
return nil
}
if c.ConnInfo == minimalConnInfo {
err = c.initConnInfo()
if err != nil {
return err
}
}
return nil
default:
if err = c.processContextFreeMsg(msg); err != nil {
return err
}
}
} }
return nil
} }
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
@ -1609,7 +1515,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
if data == nil { if data == nil {
data = map[string]interface{}{} data = map[string]interface{}{}
} }
if c.BaseConn.PID != 0 { if c.BaseConn != nil && c.BaseConn.PID != 0 {
data["pid"] = c.BaseConn.PID data["pid"] = c.BaseConn.PID
} }
@ -1641,8 +1547,8 @@ func quoteIdentifier(s string) string {
} }
func doCancel(c *Conn) error { func doCancel(c *Conn) error {
network, address := c.config.networkAddress() network, address := c.BaseConn.Config.NetworkAddress()
cancelConn, err := c.config.Dial(network, address) cancelConn, err := c.BaseConn.Config.Dial(network, address)
if err != nil { if err != nil {
return err return err
} }

View File

@ -6,6 +6,7 @@ import (
"reflect" "reflect"
"time" "time"
"github.com/jackc/pgx/base"
"github.com/jackc/pgx/pgio" "github.com/jackc/pgx/pgio"
"github.com/jackc/pgx/pgtype" "github.com/jackc/pgx/pgtype"
) )
@ -78,32 +79,7 @@ func (fd FieldDescription) Type() reflect.Type {
} }
} }
// PgError represents an error reported by the PostgreSQL server. See type PgError = base.PgError
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
// detailed field description.
type PgError struct {
Severity string
Code string
Message string
Detail string
Hint string
Position int32
InternalPosition int32
InternalQuery string
Where string
SchemaName string
TableName string
ColumnName string
DataTypeName string
ConstraintName string
File string
Line int32
Routine string
}
func (pe PgError) Error() string {
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
}
// Notice represents a notice response message reported by the PostgreSQL // Notice represents a notice response message reported by the PostgreSQL
// server. Be aware that this is distinct from LISTEN/NOTIFY notification. // server. Be aware that this is distinct from LISTEN/NOTIFY notification.

5
v4.md
View File

@ -18,6 +18,11 @@ Potential Changes:
* Decouple connection pool from connections. Connection pool should be entirely replaceable. * Decouple connection pool from connections. Connection pool should be entirely replaceable.
* Decouple various logical layers of PostgreSQL connection such that an advanced user can choose what layer to work at and pgx still handles the lower level details. e.g Normal high level query level, PostgreSQL wire protocol message level, or wire byte level. * Decouple various logical layers of PostgreSQL connection such that an advanced user can choose what layer to work at and pgx still handles the lower level details. e.g Normal high level query level, PostgreSQL wire protocol message level, or wire byte level.
* Change prepared statement usage from using name as SQL text to specifically calling prepared statement (more like database/sql). * Change prepared statement usage from using name as SQL text to specifically calling prepared statement (more like database/sql).
* Remove stdlib hack for RegisterDriverConfig now that database/sql supports better way
Minor Potential Changes:
* Change PgError error implementation to pointer method
## Changes ## Changes