mirror of https://github.com/jackc/pgx.git
Merge remote-tracking branch 'pgconn/master' into v5-dev
commit
e2769993cc
|
@ -248,21 +248,21 @@ func ParseConfig(connString string) (*Config, error) {
|
||||||
config.LookupFunc = makeDefaultResolver().LookupHost
|
config.LookupFunc = makeDefaultResolver().LookupHost
|
||||||
|
|
||||||
notRuntimeParams := map[string]struct{}{
|
notRuntimeParams := map[string]struct{}{
|
||||||
"host": struct{}{},
|
"host": {},
|
||||||
"port": struct{}{},
|
"port": {},
|
||||||
"database": struct{}{},
|
"database": {},
|
||||||
"user": struct{}{},
|
"user": {},
|
||||||
"password": struct{}{},
|
"password": {},
|
||||||
"passfile": struct{}{},
|
"passfile": {},
|
||||||
"connect_timeout": struct{}{},
|
"connect_timeout": {},
|
||||||
"sslmode": struct{}{},
|
"sslmode": {},
|
||||||
"sslkey": struct{}{},
|
"sslkey": {},
|
||||||
"sslcert": struct{}{},
|
"sslcert": {},
|
||||||
"sslrootcert": struct{}{},
|
"sslrootcert": {},
|
||||||
"target_session_attrs": struct{}{},
|
"target_session_attrs": {},
|
||||||
"min_read_buffer_size": struct{}{},
|
"min_read_buffer_size": {},
|
||||||
"service": struct{}{},
|
"service": {},
|
||||||
"servicefile": struct{}{},
|
"servicefile": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
for k, v := range settings {
|
for k, v := range settings {
|
||||||
|
@ -329,10 +329,19 @@ func ParseConfig(connString string) (*Config, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings["target_session_attrs"] == "read-write" {
|
switch tsa := settings["target_session_attrs"]; tsa {
|
||||||
|
case "read-write":
|
||||||
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadWrite
|
||||||
} else if settings["target_session_attrs"] != "any" {
|
case "read-only":
|
||||||
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", settings["target_session_attrs"])}
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsReadOnly
|
||||||
|
case "primary":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsPrimary
|
||||||
|
case "standby":
|
||||||
|
config.ValidateConnect = ValidateConnectTargetSessionAttrsStandby
|
||||||
|
case "any", "prefer-standby":
|
||||||
|
// do nothing
|
||||||
|
default:
|
||||||
|
return nil, &parseConfigError{connString: connString, msg: fmt.Sprintf("unknown target_session_attrs value: %v", tsa)}
|
||||||
}
|
}
|
||||||
|
|
||||||
return config, nil
|
return config, nil
|
||||||
|
@ -727,3 +736,48 @@ func ValidateConnectTargetSessionAttrsReadWrite(ctx context.Context, pgConn *PgC
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsReadOnly is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=read-only.
|
||||||
|
func ValidateConnectTargetSessionAttrsReadOnly(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "show transaction_read_only", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "on" {
|
||||||
|
return errors.New("connection is not read only")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsStandby is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=standby.
|
||||||
|
func ValidateConnectTargetSessionAttrsStandby(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) != "t" {
|
||||||
|
return errors.New("server is not in hot standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateConnectTargetSessionAttrsPrimary is an ValidateConnectFunc that implements libpq compatible
|
||||||
|
// target_session_attrs=primary.
|
||||||
|
func ValidateConnectTargetSessionAttrsPrimary(ctx context.Context, pgConn *PgConn) error {
|
||||||
|
result := pgConn.ExecParams(ctx, "select pg_is_in_recovery()", nil, nil, nil, nil).Read()
|
||||||
|
if result.Err != nil {
|
||||||
|
return result.Err
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(result.Rows[0][0]) == "t" {
|
||||||
|
return errors.New("server is in standby mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -541,7 +541,7 @@ func TestParseConfig(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "target_session_attrs",
|
name: "target_session_attrs read-write",
|
||||||
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-write",
|
||||||
config: &pgconn.Config{
|
config: &pgconn.Config{
|
||||||
User: "jack",
|
User: "jack",
|
||||||
|
@ -554,6 +554,87 @@ func TestParseConfig(t *testing.T) {
|
||||||
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadWrite,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs read-only",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=read-only",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsReadOnly,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs primary",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=primary",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsPrimary,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs standby",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=standby",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
ValidateConnect: pgconn.ValidateConnectTargetSessionAttrsStandby,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs prefer-standby",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=prefer-standby",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs any",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable&target_session_attrs=any",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "target_session_attrs not set (any)",
|
||||||
|
connString: "postgres://jack:secret@localhost:5432/mydb?sslmode=disable",
|
||||||
|
config: &pgconn.Config{
|
||||||
|
User: "jack",
|
||||||
|
Password: "secret",
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "mydb",
|
||||||
|
TLSConfig: nil,
|
||||||
|
RuntimeParams: map[string]string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, tt := range tests {
|
for i, tt := range tests {
|
||||||
|
|
|
@ -230,7 +230,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port)
|
||||||
pgConn.conn, err = config.DialFunc(ctx, network, address)
|
netConn, err := config.DialFunc(ctx, network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var netErr net.Error
|
var netErr net.Error
|
||||||
if errors.As(err, &netErr) && netErr.Timeout() {
|
if errors.As(err, &netErr) && netErr.Timeout() {
|
||||||
|
@ -239,24 +239,27 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
return nil, &connectError{config: config, msg: "dial error", err: err}
|
return nil, &connectError{config: config, msg: "dial error", err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.parameterStatuses = make(map[string]string)
|
pgConn.conn = netConn
|
||||||
|
pgConn.contextWatcher = newContextWatcher(netConn)
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
|
|
||||||
if fallbackConfig.TLSConfig != nil {
|
if fallbackConfig.TLSConfig != nil {
|
||||||
if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil {
|
tlsConn, err := startTLS(netConn, fallbackConfig.TLSConfig)
|
||||||
pgConn.conn.Close()
|
pgConn.contextWatcher.Unwatch() // Always unwatch `netConn` after TLS.
|
||||||
|
if err != nil {
|
||||||
|
netConn.Close()
|
||||||
return nil, &connectError{config: config, msg: "tls error", err: err}
|
return nil, &connectError{config: config, msg: "tls error", err: err}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pgConn.conn = tlsConn
|
||||||
|
pgConn.contextWatcher = newContextWatcher(tlsConn)
|
||||||
|
pgConn.contextWatcher.Watch(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.status = connStatusConnecting
|
|
||||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(
|
|
||||||
func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
|
||||||
func() { pgConn.conn.SetDeadline(time.Time{}) },
|
|
||||||
)
|
|
||||||
|
|
||||||
pgConn.contextWatcher.Watch(ctx)
|
|
||||||
defer pgConn.contextWatcher.Unwatch()
|
defer pgConn.contextWatcher.Unwatch()
|
||||||
|
|
||||||
|
pgConn.parameterStatuses = make(map[string]string)
|
||||||
|
pgConn.status = connStatusConnecting
|
||||||
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
pgConn.frontend = config.BuildFrontend(pgConn.conn, pgConn.conn)
|
||||||
|
|
||||||
startupMsg := pgproto3.StartupMessage{
|
startupMsg := pgproto3.StartupMessage{
|
||||||
|
@ -332,7 +335,7 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return pgConn, nil
|
return pgConn, nil
|
||||||
case *pgproto3.ParameterStatus:
|
case *pgproto3.ParameterStatus, *pgproto3.NoticeResponse:
|
||||||
// handled by ReceiveMessage
|
// handled by ReceiveMessage
|
||||||
case *pgproto3.ErrorResponse:
|
case *pgproto3.ErrorResponse:
|
||||||
pgConn.conn.Close()
|
pgConn.conn.Close()
|
||||||
|
@ -344,24 +347,29 @@ func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) {
|
func newContextWatcher(conn net.Conn) *ctxwatch.ContextWatcher {
|
||||||
err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103})
|
return ctxwatch.NewContextWatcher(
|
||||||
|
func() { conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
||||||
|
func() { conn.SetDeadline(time.Time{}) },
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func startTLS(conn net.Conn, tlsConfig *tls.Config) (net.Conn, error) {
|
||||||
|
err := binary.Write(conn, binary.BigEndian, []int32{8, 80877103})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, 1)
|
response := make([]byte, 1)
|
||||||
if _, err = io.ReadFull(pgConn.conn, response); err != nil {
|
if _, err = io.ReadFull(conn, response); err != nil {
|
||||||
return
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if response[0] != 'S' {
|
if response[0] != 'S' {
|
||||||
return errors.New("server refused TLS connection")
|
return nil, errors.New("server refused TLS connection")
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.conn = tls.Client(pgConn.conn, tlsConfig)
|
return tls.Client(conn, tlsConfig), nil
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
func (pgConn *PgConn) txPasswordMessage(password string) (err error) {
|
||||||
|
@ -1709,10 +1717,7 @@ func Construct(hc *HijackedConn) (*PgConn, error) {
|
||||||
cleanupDone: make(chan struct{}),
|
cleanupDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
pgConn.contextWatcher = ctxwatch.NewContextWatcher(
|
pgConn.contextWatcher = newContextWatcher(pgConn.conn)
|
||||||
func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) },
|
|
||||||
func() { pgConn.conn.SetDeadline(time.Time{}) },
|
|
||||||
)
|
|
||||||
|
|
||||||
return pgConn, nil
|
return pgConn, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,6 +161,84 @@ func TestConnectTimeout(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
connect func(connStr string) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "via context that times out",
|
||||||
|
connect: func(connStr string) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
|
||||||
|
defer cancel()
|
||||||
|
_, err := pgconn.Connect(ctx, connStr)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "via config ConnectTimeout",
|
||||||
|
connect: func(connStr string) error {
|
||||||
|
conf, err := pgconn.ParseConfig(connStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
conf.ConnectTimeout = time.Millisecond * 10
|
||||||
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
tt := tt
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
serverErrChan := make(chan error)
|
||||||
|
defer close(serverErrChan)
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var buf []byte
|
||||||
|
_, err = conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
serverErrChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sleeping to hang the TLS handshake.
|
||||||
|
time.Sleep(time.Minute)
|
||||||
|
}()
|
||||||
|
|
||||||
|
parts := strings.Split(ln.Addr().String(), ":")
|
||||||
|
host := parts[0]
|
||||||
|
port := parts[1]
|
||||||
|
connStr := fmt.Sprintf("host=%s port=%s", host, port)
|
||||||
|
|
||||||
|
errChan := make(chan error)
|
||||||
|
go func() {
|
||||||
|
err := tt.connect(connStr)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-errChan:
|
||||||
|
require.True(t, pgconn.Timeout(err), err)
|
||||||
|
case err = <-serverErrChan:
|
||||||
|
t.Fatalf("server failed with error: %s", err)
|
||||||
|
case <-time.After(time.Millisecond * 100):
|
||||||
|
t.Fatal("exceeded connection timeout without erroring out")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectInvalidUser(t *testing.T) {
|
func TestConnectInvalidUser(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -1220,6 +1298,7 @@ func TestConnOnNotice(t *testing.T) {
|
||||||
config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
|
config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
|
||||||
msg = notice.Message
|
msg = notice.Message
|
||||||
}
|
}
|
||||||
|
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect.
|
||||||
|
|
||||||
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -1876,7 +1955,11 @@ func TestConnSendBytesAndReceiveMessage(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_CONN_STRING"))
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_CONN_STRING"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the messages we expect.
|
||||||
|
|
||||||
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer closeConn(t, pgConn)
|
defer closeConn(t, pgConn)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue