mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Initial base.Connect extraction
This commit is contained in:
parent
06fb816b71
commit
65e69c5580
234
base/conn.go
234
base/conn.go
@ -1,11 +1,106 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"crypto/tls"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/pgproto3"
|
||||
)
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||
// PostgreSQL server refuses to use TLS
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16 // default: 5432
|
||||
Database string
|
||||
User string // default: OS user name
|
||||
Password string
|
||||
TLSConfig *tls.Config // config for TLS connection -- nil disables TLS
|
||||
Dial DialFunc
|
||||
RuntimeParams map[string]string // Run-time parameters to set on connection as session default values (e.g. search_path or application_name)
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) NetworkAddress() (network, address string) {
|
||||
// If host is a valid path, then address is unix socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
} else {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
}
|
||||
|
||||
return network, address
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) assignDefaults() error {
|
||||
if cc.User == "" {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.User = user.Username
|
||||
}
|
||||
|
||||
if cc.Port == 0 {
|
||||
cc.Port = 5432
|
||||
}
|
||||
|
||||
if cc.Dial == nil {
|
||||
defaultDialer := &net.Dialer{KeepAlive: 5 * time.Minute}
|
||||
cc.Dial = defaultDialer.Dial
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Conn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
type Conn struct {
|
||||
NetConn net.Conn // the underlying TCP or unix domain socket connection
|
||||
@ -14,6 +109,145 @@ type Conn struct {
|
||||
RuntimeParams map[string]string // parameters that have been reported by the server
|
||||
TxStatus byte
|
||||
Frontend *pgproto3.Frontend
|
||||
|
||||
Config ConnConfig
|
||||
}
|
||||
|
||||
func Connect(cc ConnConfig) (*Conn, error) {
|
||||
err := cc.assignDefaults()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn := new(Conn)
|
||||
conn.Config = cc
|
||||
|
||||
conn.NetConn, err = cc.Dial(cc.NetworkAddress())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.RuntimeParams = make(map[string]string)
|
||||
|
||||
if cc.TLSConfig != nil {
|
||||
if err := conn.startTLS(cc.TLSConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
conn.Frontend, err = pgproto3.NewFrontend(conn.NetConn, conn.NetConn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startupMsg := pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
Parameters: make(map[string]string),
|
||||
}
|
||||
|
||||
// Copy default run-time params
|
||||
for k, v := range cc.RuntimeParams {
|
||||
startupMsg.Parameters[k] = v
|
||||
}
|
||||
|
||||
startupMsg.Parameters["user"] = cc.User
|
||||
if cc.Database != "" {
|
||||
startupMsg.Parameters["database"] = cc.Database
|
||||
}
|
||||
|
||||
if _, err := conn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
msg, err := conn.ReceiveMessage()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.BackendKeyData:
|
||||
conn.PID = msg.ProcessID
|
||||
conn.SecretKey = msg.SecretKey
|
||||
case *pgproto3.Authentication:
|
||||
if err = conn.rxAuthenticationX(msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
return conn, nil
|
||||
case *pgproto3.ParameterStatus:
|
||||
// handled by ReceiveMessage
|
||||
case *pgproto3.ErrorResponse:
|
||||
return nil, PgError{
|
||||
Severity: msg.Severity,
|
||||
Code: msg.Code,
|
||||
Message: msg.Message,
|
||||
Detail: msg.Detail,
|
||||
Hint: msg.Hint,
|
||||
Position: msg.Position,
|
||||
InternalPosition: msg.InternalPosition,
|
||||
InternalQuery: msg.InternalQuery,
|
||||
Where: msg.Where,
|
||||
SchemaName: msg.SchemaName,
|
||||
TableName: msg.TableName,
|
||||
ColumnName: msg.ColumnName,
|
||||
DataTypeName: msg.DataTypeName,
|
||||
ConstraintName: msg.ConstraintName,
|
||||
File: msg.File,
|
||||
Line: msg.Line,
|
||||
Routine: msg.Routine,
|
||||
}
|
||||
default:
|
||||
return nil, errors.New("unexpected message")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) startTLS(tlsConfig *tls.Config) (err error) {
|
||||
err = binary.Write(conn.NetConn, binary.BigEndian, []int32{8, 80877103})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response := make([]byte, 1)
|
||||
if _, err = io.ReadFull(conn.NetConn, response); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if response[0] != 'S' {
|
||||
return ErrTLSRefused
|
||||
}
|
||||
|
||||
conn.NetConn = tls.Client(conn.NetConn, tlsConfig)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) {
|
||||
switch msg.Type {
|
||||
case pgproto3.AuthTypeOk:
|
||||
case pgproto3.AuthTypeCleartextPassword:
|
||||
err = c.txPasswordMessage(c.Config.Password)
|
||||
case pgproto3.AuthTypeMD5Password:
|
||||
digestedPassword := "md5" + hexMD5(hexMD5(c.Config.Password+c.Config.User)+string(msg.Salt[:]))
|
||||
err = c.txPasswordMessage(digestedPassword)
|
||||
default:
|
||||
err = errors.New("Received unknown authentication message")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (conn *Conn) txPasswordMessage(password string) (err error) {
|
||||
msg := &pgproto3.PasswordMessage{Password: password}
|
||||
_, err = conn.NetConn.Write(msg.Encode(nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func hexMD5(s string) string {
|
||||
hash := md5.New()
|
||||
io.WriteString(hash, s)
|
||||
return hex.EncodeToString(hash.Sum(nil))
|
||||
}
|
||||
|
||||
func (conn *Conn) ReceiveMessage() (pgproto3.BackendMessage, error) {
|
||||
|
152
conn.go
152
conn.go
@ -13,8 +13,6 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@ -92,27 +90,11 @@ type ConnConfig struct {
|
||||
PreferSimpleProtocol bool
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
}
|
||||
|
||||
return network, address
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
// Use ConnPool to manage access to multiple database connections from multiple
|
||||
// goroutines.
|
||||
type Conn struct {
|
||||
BaseConn base.Conn
|
||||
BaseConn *base.Conn
|
||||
wbuf []byte
|
||||
config ConnConfig // config used when establishing this connection
|
||||
preparedStatements map[string]*PreparedStatement
|
||||
@ -196,7 +178,7 @@ var ErrDeadConn = errors.New("conn is dead")
|
||||
|
||||
// ErrTLSRefused occurs when the connection attempt requires TLS and the
|
||||
// PostgreSQL server refuses to use TLS
|
||||
var ErrTLSRefused = errors.New("server refused TLS connection")
|
||||
var ErrTLSRefused = base.ErrTLSRefused
|
||||
|
||||
// ErrConnBusy occurs when the connection is busy (for example, in the middle of
|
||||
// reading query results) and another action is attempted.
|
||||
@ -237,41 +219,17 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||
}
|
||||
c.logger = c.config.Logger
|
||||
|
||||
if c.config.User == "" {
|
||||
user, err := user.Current()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.config.User = user.Username
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"User": c.config.User})
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.Port == 0 {
|
||||
c.config.Port = 5432
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
c.log(LogLevelDebug, "Using default connection config", map[string]interface{}{"Port": c.config.Port})
|
||||
}
|
||||
}
|
||||
|
||||
c.onNotice = config.OnNotice
|
||||
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
d := defaultDialer()
|
||||
c.config.Dial = d.Dial
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Host})
|
||||
}
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
err = c.connect(config, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
|
||||
}
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
err = c.connect(config, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -284,9 +242,19 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||
c.BaseConn = base.Conn{}
|
||||
c.BaseConn.NetConn, err = c.config.Dial(network, address)
|
||||
func (c *Conn) connect(config ConnConfig, tlsConfig *tls.Config) (err error) {
|
||||
cc := base.ConnConfig{
|
||||
Host: config.Host,
|
||||
Port: config.Port,
|
||||
Database: config.Database,
|
||||
User: config.User,
|
||||
Password: config.Password,
|
||||
TLSConfig: tlsConfig,
|
||||
Dial: base.DialFunc(config.Dial),
|
||||
RuntimeParams: config.RuntimeParams,
|
||||
}
|
||||
|
||||
c.BaseConn, err = base.Connect(cc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -299,7 +267,6 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||
}
|
||||
}()
|
||||
|
||||
c.BaseConn.RuntimeParams = make(map[string]string)
|
||||
c.preparedStatements = make(map[string]*PreparedStatement)
|
||||
c.channels = make(map[string]struct{})
|
||||
c.cancelQueryCompleted = make(chan struct{})
|
||||
@ -312,81 +279,20 @@ func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tl
|
||||
c.status = connStatusIdle
|
||||
c.mux.Unlock()
|
||||
|
||||
if tlsConfig != nil {
|
||||
if c.shouldLog(LogLevelDebug) {
|
||||
c.log(LogLevelDebug, "starting TLS handshake", nil)
|
||||
}
|
||||
if err := c.startTLS(tlsConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
// Replication connections can't execute the queries to
|
||||
// populate the c.PgTypes and c.pgsqlAfInet
|
||||
if _, ok := config.RuntimeParams["replication"]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.BaseConn.Frontend, err = pgproto3.NewFrontend(c.BaseConn.NetConn, c.BaseConn.NetConn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startupMsg := pgproto3.StartupMessage{
|
||||
ProtocolVersion: pgproto3.ProtocolVersionNumber,
|
||||
Parameters: make(map[string]string),
|
||||
}
|
||||
|
||||
// Copy default run-time params
|
||||
for k, v := range config.RuntimeParams {
|
||||
startupMsg.Parameters[k] = v
|
||||
}
|
||||
|
||||
startupMsg.Parameters["user"] = c.config.User
|
||||
if c.config.Database != "" {
|
||||
startupMsg.Parameters["database"] = c.config.Database
|
||||
}
|
||||
|
||||
if _, err := c.BaseConn.NetConn.Write(startupMsg.Encode(nil)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.pendingReadyForQueryCount = 1
|
||||
|
||||
for {
|
||||
msg, err := c.rxMsg()
|
||||
if c.ConnInfo == minimalConnInfo {
|
||||
err = c.initConnInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.BackendKeyData:
|
||||
c.BaseConn.PID = msg.ProcessID
|
||||
c.BaseConn.SecretKey = msg.SecretKey
|
||||
case *pgproto3.Authentication:
|
||||
if err = c.rxAuthenticationX(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
case *pgproto3.ReadyForQuery:
|
||||
c.rxReadyForQuery(msg)
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "connection established", nil)
|
||||
}
|
||||
|
||||
// Replication connections can't execute the queries to
|
||||
// populate the c.PgTypes and c.pgsqlAfInet
|
||||
if _, ok := config.RuntimeParams["replication"]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.ConnInfo == minimalConnInfo {
|
||||
err = c.initConnInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
default:
|
||||
if err = c.processContextFreeMsg(msg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
|
||||
@ -1609,7 +1515,7 @@ func (c *Conn) log(lvl LogLevel, msg string, data map[string]interface{}) {
|
||||
if data == nil {
|
||||
data = map[string]interface{}{}
|
||||
}
|
||||
if c.BaseConn.PID != 0 {
|
||||
if c.BaseConn != nil && c.BaseConn.PID != 0 {
|
||||
data["pid"] = c.BaseConn.PID
|
||||
}
|
||||
|
||||
@ -1641,8 +1547,8 @@ func quoteIdentifier(s string) string {
|
||||
}
|
||||
|
||||
func doCancel(c *Conn) error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
network, address := c.BaseConn.Config.NetworkAddress()
|
||||
cancelConn, err := c.BaseConn.Config.Dial(network, address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
28
messages.go
28
messages.go
@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/base"
|
||||
"github.com/jackc/pgx/pgio"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
)
|
||||
@ -78,32 +79,7 @@ func (fd FieldDescription) Type() reflect.Type {
|
||||
}
|
||||
}
|
||||
|
||||
// PgError represents an error reported by the PostgreSQL server. See
|
||||
// http://www.postgresql.org/docs/9.3/static/protocol-error-fields.html for
|
||||
// detailed field description.
|
||||
type PgError struct {
|
||||
Severity string
|
||||
Code string
|
||||
Message string
|
||||
Detail string
|
||||
Hint string
|
||||
Position int32
|
||||
InternalPosition int32
|
||||
InternalQuery string
|
||||
Where string
|
||||
SchemaName string
|
||||
TableName string
|
||||
ColumnName string
|
||||
DataTypeName string
|
||||
ConstraintName string
|
||||
File string
|
||||
Line int32
|
||||
Routine string
|
||||
}
|
||||
|
||||
func (pe PgError) Error() string {
|
||||
return pe.Severity + ": " + pe.Message + " (SQLSTATE " + pe.Code + ")"
|
||||
}
|
||||
type PgError = base.PgError
|
||||
|
||||
// Notice represents a notice response message reported by the PostgreSQL
|
||||
// server. Be aware that this is distinct from LISTEN/NOTIFY notification.
|
||||
|
5
v4.md
5
v4.md
@ -18,6 +18,11 @@ Potential Changes:
|
||||
* Decouple connection pool from connections. Connection pool should be entirely replaceable.
|
||||
* Decouple various logical layers of PostgreSQL connection such that an advanced user can choose what layer to work at and pgx still handles the lower level details. e.g Normal high level query level, PostgreSQL wire protocol message level, or wire byte level.
|
||||
* Change prepared statement usage from using name as SQL text to specifically calling prepared statement (more like database/sql).
|
||||
* Remove stdlib hack for RegisterDriverConfig now that database/sql supports better way
|
||||
|
||||
Minor Potential Changes:
|
||||
|
||||
* Change PgError error implementation to pointer method
|
||||
|
||||
## Changes
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user